Source code for lw_pipeline.output_registration

"""Output registration system for pipeline steps."""

# Authors: The Lightweight Pipeline developers
# SPDX-License-Identifier: BSD-3-Clause

import functools
from fnmatch import fnmatch

# Module-level dict reserved for class-level output registrations.
# NOTE(review): nothing in this module writes to it — register_output stores
# its metadata as attributes on the decorated function instead, and
# Output_Registry reads those attributes. Confirm whether any other module
# populates or reads this before relying on (or removing) it.
_CLASS_OUTPUTS = {}


def register_output(
    name,
    description="",
    enabled_by_default=True,
    group=None,
    check_exists=False,
    extension=None,
    suffix=None,
    use_bids_structure=None,
    custom_dir=None,
    **extra_path_params,
):
    """
    Register a method as an optional output generator.

    This allows steps to define multiple outputs that can be selectively
    generated based on CLI arguments or config settings.

    Parameters
    ----------
    name : str
        Name of the output (used for CLI selection, e.g., "plot", "stats").
    description : str, optional
        Human-readable description of the output.
    enabled_by_default : bool, optional
        Whether this output is generated by default. Default is True.
    group : str, optional
        Optional group name for categorization (future use).
    check_exists : bool, optional
        If True, check if output file exists before running the method.
        Respects overwrite_mode setting. Prevents expensive computations
        when output already exists. Default is False.
    extension : str, optional
        Default file extension (e.g., '.png', '.csv'). Used for both
        existence checking and as default when saving.
    suffix : str, optional
        Default BIDS suffix. Used for both checking and saving.
    use_bids_structure : bool, optional
        Default for BIDS path structure.
    custom_dir : str or Path, optional
        Default custom output directory.
    **extra_path_params : dict
        Additional default path parameters (e.g., datatype, processing).

    Returns
    -------
    callable
        Decorated function.

    Notes
    -----
    At call time the wrapped method returns ``None`` without running when
    ``self.should_generate_output(name)`` is false. The existence check runs
    only when ``check_exists`` is True *and* at least one default path
    parameter above is not None. In that case the method is skipped (returns
    ``None``) when ``self.output_manager._should_overwrite(path)`` is false.
    If building the path or the overwrite check raises, a warning is logged
    through ``self.config.logger`` (when present) and the method runs anyway.

    Examples
    --------
    >>> class MyStep(Pipeline_Step):
    ...     @register_output(
    ...         "expensive_plot",
    ...         "Channel visualization",
    ...         check_exists=True,
    ...         extension=".png"
    ...     )
    ...     def generate_plot(self):
    ...         # Skipped if file exists and overwrite_mode='never'
    ...         data = expensive_computation()
    ...         fig = create_plot(data)
    ...         # extension=".png" used automatically
    ...         self.output_manager.save_figure(fig, "expensive_plot")
    """

    def decorator(func):
        # Registration metadata lives as attributes on the function itself.
        # functools.wraps copies func.__dict__ onto the wrapper, so the
        # returned wrapper exposes the same attributes for registry scanning.
        func._is_registered_output = True
        func._output_name = name
        func._output_description = description
        func._output_enabled_by_default = enabled_by_default
        func._output_group = group
        func._check_exists = check_exists

        # Default path parameters; None means "not specified" and is dropped.
        func._default_path_params = {
            key: value
            for key, value in {
                "extension": extension,
                "suffix": suffix,
                "use_bids_structure": use_bids_structure,
                "custom_dir": custom_dir,
                **extra_path_params,
            }.items()
            if value is not None
        }

        # Aggregate view of the metadata above. Nothing in this module reads
        # it, and _CLASS_OUTPUTS is not updated here.
        # NOTE(review): presumably consumed elsewhere; verify before removing.
        func._output_info = {
            "name": name,
            "description": description,
            "enabled_by_default": enabled_by_default,
            "group": group,
            "check_exists": check_exists,
            "default_path_params": func._default_path_params,
            "method_name": func.__name__,
        }

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Output disabled by config/CLI selection.
            if not self.should_generate_output(name):
                return None

            # Optional existence check; skipped when no default path params.
            if check_exists and func._default_path_params:
                try:
                    output_path = self.get_output_path(
                        name, **func._default_path_params
                    )
                    if not self.output_manager._should_overwrite(output_path):
                        return None
                except Exception as e:
                    # Best effort: a failed check must not block generation.
                    # Look the logger up defensively so a step without
                    # `config` does not raise AttributeError from here.
                    config = getattr(self, "config", None)
                    logger = getattr(config, "logger", None)
                    if logger:
                        logger.warning(
                            f"Could not check existence for '{name}': {e}. Proceeding."
                        )

            return func(self, *args, **kwargs)

        return wrapper

    return decorator
class Output_Registry:
    """
    Registry for managing registered outputs in a pipeline step.

    This class is used internally by Pipeline_Step to track and filter
    registered outputs. Registration metadata is read from the attributes
    that ``register_output`` attaches to decorated methods.
    """

    def __init__(self, step_instance):
        """
        Initialize the registry for a pipeline step.

        Parameters
        ----------
        step_instance : Pipeline_Step
            The step instance to scan for registered outputs.
        """
        self.step = step_instance
        # Maps output name -> info dict built by _scan_step_for_outputs.
        self._registered = {}
        self._scan_step_for_outputs()

    def _scan_step_for_outputs(self):
        """
        Collect methods decorated with ``@register_output``.

        Walks the step class's MRO (most-derived class first, ``object``
        skipped) and inspects each class's own ``__dict__``. Public
        attributes carrying the ``_is_registered_output`` marker are recorded
        in ``self._registered``, keyed by output name.

        Precedence rules:

        * If several methods register the same output name, the first one
          found in the MRO (the most-derived) wins.
        * If a subclass overrides a decorated method with another decorated
          method, the base-class registration for that attribute is ignored.
          Otherwise the base output name would be bound (via the instance) to
          the subclass's method, which checks and generates a different output.
        * An undecorated override does not hide the base registration; the
          stored bound method then resolves to the override.
        """
        step_class = self.step.__class__
        # Attribute names already taken by a decorated method in a more-derived class.
        claimed_attrs = set()

        for cls in step_class.__mro__:
            if cls is object:
                continue
            for attr_name, attr_value in cls.__dict__.items():
                # Skip private attributes and output_registry (the original
                # code skips the latter to avoid recursion).
                if attr_name.startswith("_") or attr_name == "output_registry":
                    continue
                if not hasattr(attr_value, "_is_registered_output"):
                    continue
                if attr_name in claimed_attrs:
                    continue
                claimed_attrs.add(attr_name)

                output_name = attr_value._output_name
                if output_name in self._registered:
                    continue
                self._registered[output_name] = {
                    "name": output_name,
                    "description": attr_value._output_description,
                    "enabled_by_default": attr_value._output_enabled_by_default,
                    "group": attr_value._output_group,
                    # Resolved through the instance so it is bound to self.step.
                    "method": getattr(self.step, attr_name),
                    "check_exists": attr_value._check_exists,
                    "default_path_params": attr_value._default_path_params,
                }

    def get_all(self):
        """
        Get all registered outputs.

        Returns
        -------
        dict
            Shallow copy of the mapping from output names to their info.
        """
        return self._registered.copy()

    def get_enabled_by_default(self):
        """
        Get outputs that are enabled by default.

        Returns
        -------
        list
            List of output names enabled by default, in registration order.
        """
        return [
            name
            for name, info in self._registered.items()
            if info["enabled_by_default"]
        ]

    def should_generate(self, output_name, config):
        """
        Determine if an output should be generated.

        Parameters
        ----------
        output_name : str
            Name of the output to check.
        config : Config
            Configuration object; its optional ``outputs_to_skip`` and
            ``outputs_to_generate`` attributes are read. Each may be a
            pattern, a collection of patterns, or a dict mapping step ids
            (or ``"*"``) to patterns.

        Returns
        -------
        bool
            True if output should be generated.
        """
        if output_name not in self._registered:
            return False

        # An explicit skip takes precedence over everything else.
        outputs_skip = getattr(config, "outputs_to_skip", None)
        if outputs_skip is not None and self._matches_patterns(
            output_name, outputs_skip
        ):
            return False

        outputs_config = getattr(config, "outputs_to_generate", None)
        if outputs_config is None:
            # No explicit selection: fall back to the registered default.
            return self._registered[output_name]["enabled_by_default"]

        return self._matches_patterns(output_name, outputs_config)

    @staticmethod
    def _as_pattern_list(patterns):
        """
        Normalize a pattern specification to a list of fnmatch patterns.

        A bare string is treated as one pattern rather than being iterated
        character by character. Any other iterable is converted to a list.
        Returns None when ``patterns`` is not iterable.
        """
        if isinstance(patterns, str):
            return [patterns]
        try:
            return list(patterns)
        except TypeError:
            return None

    def _matches_patterns(self, output_name, patterns_config):
        """
        Check if an output name matches any pattern in the config.

        Parameters
        ----------
        output_name : str
            Name of the output to check.
        patterns_config : dict, list, tuple, set or str
            A dict is step-scoped: ``self.step.short_id`` is looked up first,
            then the ``"*"`` wildcard key. Anything else applies to all steps.

        Returns
        -------
        bool
            True if output matches any pattern.
        """
        if isinstance(patterns_config, dict):
            patterns = patterns_config.get(self.step.short_id)
            if patterns is None:
                patterns = patterns_config.get("*")
            # This step is not mentioned: nothing matches.
            if patterns is None:
                return False
        else:
            patterns = patterns_config

        pattern_list = self._as_pattern_list(patterns)
        if pattern_list is None:
            # Unknown config format.
            return False
        return any(fnmatch(output_name, pattern) for pattern in pattern_list)

    def list_outputs(self, include_disabled=True):
        """
        List all registered outputs with their information.

        Parameters
        ----------
        include_disabled : bool, optional
            Include outputs that are disabled by default. Default is True.

        Returns
        -------
        list
            List of tuples (name, description, enabled_by_default).
        """
        return [
            (name, info["description"], info["enabled_by_default"])
            for name, info in self._registered.items()
            if include_disabled or info["enabled_by_default"]
        ]