"""Output registration system for pipeline steps."""
# Authors: The Lightweight Pipeline developers
# SPDX-License-Identifier: BSD-3-Clause
import functools
from fnmatch import fnmatch
# Class-level storage for registered outputs (populated at decoration time)
_CLASS_OUTPUTS = {}
def register_output(
    name,
    description="",
    enabled_by_default=True,
    group=None,
    check_exists=False,
    extension=None,
    suffix=None,
    use_bids_structure=None,
    custom_dir=None,
    **extra_path_params,
):
    """
    Register a method as an optional output generator.

    This allows steps to define multiple outputs that can be selectively
    generated based on CLI arguments or config settings.

    The decorated method is wrapped so that, when called, it first asks
    ``self.should_generate_output(name)`` and returns ``None`` without
    running the method if that check fails.

    Parameters
    ----------
    name : str
        Name of the output (used for CLI selection, e.g., "plot", "stats").
    description : str, optional
        Human-readable description of the output.
    enabled_by_default : bool, optional
        Whether this output is generated by default. Default is True.
    group : str, optional
        Optional group name for categorization (future use).
    check_exists : bool, optional
        If True, check if output file exists before running the method.
        Respects overwrite_mode setting. Prevents expensive computations
        when output already exists. The check only runs when at least one
        default path parameter (``extension``, ``suffix``,
        ``use_bids_structure``, ``custom_dir`` or an extra keyword) is not
        None. If resolving the path fails, a warning is logged (when
        ``self.config.logger`` is available) and the method runs anyway.
        Default is False.
    extension : str, optional
        Default file extension (e.g., '.png', '.csv'). Used for both
        existence checking and as default when saving.
    suffix : str, optional
        Default BIDS suffix. Used for both checking and saving.
    use_bids_structure : bool, optional
        Default for BIDS path structure.
    custom_dir : str or Path, optional
        Default custom output directory.
    **extra_path_params : dict
        Additional default path parameters (e.g., datatype, processing).

    Returns
    -------
    callable
        Decorator that wraps the method. The wrapper returns ``None`` when
        the output is disabled, or when it already exists and should not be
        overwritten; otherwise it returns the method's own result.

    Examples
    --------
    >>> class MyStep(Pipeline_Step):
    ...     @register_output(
    ...         "expensive_plot",
    ...         "Channel visualization",
    ...         check_exists=True,
    ...         extension=".png",
    ...     )
    ...     def generate_plot(self):
    ...         # Skipped if file exists and overwrite_mode='never'
    ...         data = expensive_computation()
    ...         fig = create_plot(data)
    ...         # extension=".png" used automatically
    ...         self.output_manager.save_figure(fig, "expensive_plot")
    """

    def decorator(func):
        # Store registration info as function attributes (backward
        # compatibility). functools.wraps copies func.__dict__ onto the
        # wrapper, so these attributes are visible on the class attribute
        # that Output_Registry scans.
        func._is_registered_output = True
        func._output_name = name
        func._output_description = description
        func._output_enabled_by_default = enabled_by_default
        func._output_group = group
        func._check_exists = check_exists
        # Default path parameters, with unset (None) entries dropped.
        func._default_path_params = {
            k: v
            for k, v in {
                "extension": extension,
                "suffix": suffix,
                "use_bids_structure": use_bids_structure,
                "custom_dir": custom_dir,
                **extra_path_params,
            }.items()
            if v is not None
        }
        # Aggregated metadata snapshot (not read anywhere in this file).
        func._output_info = {
            "name": name,
            "description": description,
            "enabled_by_default": enabled_by_default,
            "group": group,
            "check_exists": check_exists,
            "default_path_params": func._default_path_params,
            "method_name": func.__name__,
        }

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Check if output should be generated (config/CLI)
            if not self.should_generate_output(name):
                return None
            # Existence check only when requested AND there is something to
            # build a path from.
            if check_exists and func._default_path_params:
                try:
                    output_path = self.get_output_path(
                        name, **func._default_path_params
                    )
                    if not self.output_manager._should_overwrite(output_path):
                        return None
                except Exception as e:
                    # Best effort: a failed path check must not block the
                    # output. Look the logger up defensively so a step
                    # without a ``config`` attribute does not turn this
                    # warning path into an AttributeError.
                    config = getattr(self, "config", None)
                    logger = getattr(config, "logger", None)
                    if logger:
                        logger.warning(
                            f"Could not check existence for '{name}': {e}. Proceeding."
                        )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator
class Output_Registry:
    """
    Registry for managing registered outputs in a pipeline step.

    This class is used internally by Pipeline_Step to track and filter
    registered outputs.
    """

    def __init__(self, step_instance):
        """
        Initialize the registry for a pipeline step.

        Parameters
        ----------
        step_instance : Pipeline_Step
            The step instance to scan for registered outputs.
        """
        self.step = step_instance
        self._registered = {}
        self._scan_step_for_outputs()

    def _scan_step_for_outputs(self):
        """
        Scan the step class for methods decorated with @register_output.

        Walks the class MRO from most- to least-derived and inspects only
        each class's own ``__dict__``. Public attributes carrying the
        ``_is_registered_output`` marker are recorded under their output
        name, together with the method bound to the step instance.

        Precedence rules:

        * The first registration seen for an output name wins, so subclass
          registrations take precedence over those of their bases.
        * Once a more-derived class registers an output under a given
          attribute name, base-class registrations under that same
          attribute are ignored. ``getattr(self.step, attr_name)`` always
          resolves to the most-derived attribute, so the base's metadata
          (possibly a different output name) would otherwise be paired with
          the subclass's method.
        * Attributes whose name starts with ``_`` (and ``output_registry``)
          are never considered.
        """
        step_class = self.step.__class__
        # Attribute names already registered by a more-derived class.
        claimed_attrs = set()
        for cls in step_class.__mro__:
            if cls is object:
                continue
            for attr_name, attr_value in cls.__dict__.items():
                # Skip private attributes and output_registry to avoid recursion
                if attr_name.startswith("_") or attr_name == "output_registry":
                    continue
                if not hasattr(attr_value, "_is_registered_output"):
                    continue
                if attr_name in claimed_attrs:
                    # Shadowed by a registered override in a subclass.
                    continue
                claimed_attrs.add(attr_name)
                output_name = attr_value._output_name
                if output_name in self._registered:
                    continue
                self._registered[output_name] = {
                    "name": output_name,
                    "description": attr_value._output_description,
                    "enabled_by_default": attr_value._output_enabled_by_default,
                    "group": attr_value._output_group,
                    "method": getattr(self.step, attr_name),
                    "check_exists": attr_value._check_exists,
                    "default_path_params": attr_value._default_path_params,
                }

    def get_all(self):
        """
        Get all registered outputs.

        Returns
        -------
        dict
            Shallow copy of the mapping from output names to their info
            dicts (the info dicts themselves are shared).
        """
        return self._registered.copy()

    def get_enabled_by_default(self):
        """
        Get outputs that are enabled by default.

        Returns
        -------
        list
            List of output names enabled by default, in registration order.
        """
        return [
            name
            for name, info in self._registered.items()
            if info["enabled_by_default"]
        ]

    def should_generate(self, output_name, config):
        """
        Determine if an output should be generated.

        Parameters
        ----------
        output_name : str
            Name of the output to check.
        config : Config
            Configuration object with outputs_to_generate and outputs_to_skip
            settings. Missing attributes are treated as None.

        Returns
        -------
        bool
            True if output should be generated. Unregistered outputs are
            never generated; ``outputs_to_skip`` takes precedence over
            ``outputs_to_generate``; when ``outputs_to_generate`` is None the
            output's ``enabled_by_default`` flag decides.
        """
        if output_name not in self._registered:
            return False
        # Explicit skips take precedence over everything else.
        outputs_skip = getattr(config, "outputs_to_skip", None)
        if outputs_skip is not None:
            if self._matches_patterns(output_name, outputs_skip):
                return False
        outputs_config = getattr(config, "outputs_to_generate", None)
        # If no config specified, use default behavior
        if outputs_config is None:
            return self._registered[output_name]["enabled_by_default"]
        return self._matches_patterns(output_name, outputs_config)

    def _matches_patterns(self, output_name, patterns_config):
        """
        Check if an output name matches any pattern in the config.

        Parameters
        ----------
        output_name : str
            Name of the output to check.
        patterns_config : dict, list, tuple, set or str
            Shell-style patterns. A dict is step-scoped: it is keyed by
            ``self.step.short_id``, with ``"*"`` as a fallback for all steps.
            Any other accepted type is a global pattern collection. A bare
            string is treated as a single pattern.

        Returns
        -------
        bool
            True if output matches any pattern. Steps missing from a
            step-scoped dict, and unsupported config types, never match.
        """
        # fnmatchcase behaves identically on every OS; plain fnmatch
        # normalizes case (and path separators) on Windows only.
        from fnmatch import fnmatchcase

        if isinstance(patterns_config, dict):
            # Exact step id first, then the "*" wildcard entry.
            patterns = patterns_config.get(self.step.short_id)
            if patterns is None:
                patterns = patterns_config.get("*")
            if patterns is None:
                # This step is not mentioned in the config -> no match.
                return False
        elif isinstance(patterns_config, (str, list, tuple, set, frozenset)):
            patterns = patterns_config
        else:
            # Unknown config format, return False
            return False
        if isinstance(patterns, str):
            # A bare string is one pattern, not a sequence of characters.
            patterns = (patterns,)
        return any(fnmatchcase(output_name, pattern) for pattern in patterns)

    def list_outputs(self, include_disabled=True):
        """
        List all registered outputs with their information.

        Parameters
        ----------
        include_disabled : bool, optional
            Include outputs that are disabled by default. Default is True.

        Returns
        -------
        list
            List of tuples (name, description, enabled_by_default).
        """
        return [
            (name, info["description"], info["enabled_by_default"])
            for name, info in self._registered.items()
            if include_disabled or info["enabled_by_default"]
        ]