Source code for lw_pipeline.helper.mne

"""MNE helpers."""

# Authors: The Lightweight Pipeline developers
# SPDX-License-Identifier: BSD-3-Clause

import warnings

import mne
from mne_bids import BIDSPath, find_matching_paths, read_raw_bids


[docs] def raw_from_source(source, n_jobs = 1, suppress_runtime_warning=True, **kwargs): """ Produce an mne raw object from different sources. Possible sources: (lists will be concatenated) - BIDSPath - list of BIDSPath - mne.io.BaseRaw - list of mne.io.BaseRaw Parameters ---------- source : BIDSPath | list of BIDSPath | mne.io.BaseRaw | list of mne.io.BaseRaw The source to read the raw data from. suppress_runtime_warning : bool If True, suppresses RuntimeWarning when reading raw data. This is useful when reading raw data from BIDSPath, as it may raise a RuntimeWarning if coordsystem/electrode data, etc. is not found. kwargs : dict Additional keyword arguments to pass to mne.io.read_raw. Returns ------- raw : mne.io.BaseRaw The raw data object. Raises ------ ValueError If the source is not a BIDSPath, list of BIDSPath, mne.io.BaseRaw, or list of mne.io.BaseRaw. """ # check if the source_file is an instance of BIDSPath if isinstance(source, BIDSPath): try: if suppress_runtime_warning: with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) raw = read_raw_bids(source, **kwargs) else: raw = read_raw_bids(source, **kwargs) except Exception: raw = read_raw_bids(source, {"encoding": "latin1"}) # check if it is a list of bids paths elif isinstance(source, list) and all( [isinstance(fpath, BIDSPath) for fpath in source] ): try: if suppress_runtime_warning: with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) raws = [read_raw_bids(fpath, **kwargs) for fpath in source] else: raws = [read_raw_bids(fpath, **kwargs) for fpath in source] except Exception: raws = [read_raw_bids(fpath, {"encoding": "latin1"}) for fpath in source] # Check if all raws have the same sampling frequency sfreqs = [raw.info["sfreq"] for raw in raws] if len(set(sfreqs)) > 1: # If sampling frequencies differ, resample to the lowest frequency min_sfreq = min(sfreqs) print("Resampling to: ", min_sfreq) for raw in raws: if raw.info["sfreq"] != min_sfreq: print("Resampling:", raw.info["file_id"]) raw.resample(min_sfreq, n_jobs=n_jobs) raw = mne.io.concatenate_raws(raws) elif isinstance(source, mne.io.BaseRaw): raw = source elif isinstance(source, list) and all( [isinstance(raw, mne.io.BaseRaw) for raw in source] ): raw = mne.io.concatenate_raws(source) else: raise ValueError(f"Unknown source type {type(source)}.") return raw
[docs] def find_ch(identifier, ch_names, return_identifiers=False, error_level="ignore"): """ Find actual channel name by identifier (e.g. Fp1 for "EEG Fp1"). Parameters ---------- identifier : str | list of str The identifier to search for (e.g. "Fp1") or list of identifiers. ch_names : list of str The list of channel names to search in. return_identifiers : bool Whether to return the identifiers of found channels. Only used if identifier is a list. error_level : str The error level to use if the channel is not found. Can be "ignore", "warn", or "raise". Returns ------- ch_name : str | None The actual channel name if found, else None. """ if type(identifier) is list: return find_chs(identifier, ch_names, return_identifiers, error_level) elif not isinstance(identifier, str): raise ValueError("Identifier must be a string or list of strings.") for ch in ch_names: if identifier.lower() in ch.lower(): return ch if error_level == "warn": warnings.warn(f"Channel {identifier} not found in {ch_names}.") elif error_level == "raise": raise ValueError(f"Channel {identifier} not found in {ch_names}.") return None
[docs] def find_chs(identifiers, ch_names, return_identifiers=False, error_level="ignore"): """ Find actual channel names by identifiers (e.g. Fp1 for "EEG Fp1"). Parameters ---------- identifiers : list of str The identifiers to search for (e.g. ["Fp1", "Fp2"]). ch_names : list of str The list of channel names to search in. return_identifiers : bool Whether to return the identifiers of found channels. error_level : str The error level to use if a channel is not found. Can be "ignore", "warn", or "raise". Returns ------- ch_names : list of str | None The actual channel names if found, else None. """ found_identifiers = [] found_ch_names = [] for identifier in identifiers: ch_name = find_ch(identifier, ch_names, error_level=error_level) if ch_name is not None: found_ch_names.append(ch_name) found_identifiers.append(identifier) elif error_level == "raise": raise ValueError(f"Channel {identifier} not found in {ch_names}.") if return_identifiers: return found_ch_names, found_identifiers return found_ch_names
[docs] def find_assoc_deriv_bidspath( bids_path, description, suffix=None, extension=None, alt_description=None ): """ Get bids path for derivative file based on an associated bids path and description. Use find_matching_paths based on subject, session, task, run and specify separately description, suffix and extension. Parameters ---------- bids_path : mne_bids.BIDSPath BIDSPath object with the path to the raw data file. description : str Description of the derivative file. suffix : str Suffix of the derivative file. extension : str Extension of the derivative file. alt_description : str If the first one is not found, allow to specify an alternative. Returns ------- bids_path : mne_bids.BIDSPath BIDSPath object with the path to the derivative file. """ if suffix is None: suffix = bids_path.suffix if extension is None: extension = bids_path.extension derivative_bids_path = find_matching_paths( subjects=bids_path.subject, sessions=bids_path.session, tasks=bids_path.task, runs=bids_path.run, descriptions=description, suffixes=suffix, extensions=extension, root=bids_path.root, ) if len(derivative_bids_path) == 1: return derivative_bids_path[0] else: if alt_description: return find_assoc_deriv_bidspath(bids_path, alt_description, suffix, extension) raise ValueError( f"No derivative file found for desc {description} and subject " f"{bids_path.subject} session {bids_path.session} task " f"{bids_path.task} run {bids_path.run}" )
# def save_assoc_deriv( # bids_path, description, suffix="eeg", extension=".fif", overwrite=False, **kwargs # ): # """ # Save a derivative file associated with a BIDSPath. # Parameters # ---------- # bids_path : mne_bids.BIDSPath # BIDSPath object with the path to the raw data file. # description : str # Description of the derivative file. # suffix : str # Suffix of the derivative file. # extension : str # Extension of the derivative file. # overwrite : bool # Whether to overwrite existing files. # kwargs : dict # Additional keyword arguments to pass to bids_path.copy(). # Returns # ------- # deriv_bids_path : mne_bids.BIDSPath # BIDSPath object with the path to the saved derivative file. # """ # deriv_bids_path = bids_path.copy( # suffix=suffix, # description=description, # extension=extension, # **kwargs, # ) # if # return deriv_bids_path
[docs] def add_annotation_prefix( annotations, prefix="BAD ", regex="(Active Stimulation)|(Buffer Stimulation)" ): """ Add a prefix to annotation descriptions with descriptions matching a regex. Operates in-place. Parameters ---------- annotations : mne.Annotations The annotations object containing the descriptions to modify. prefix : str The prefix to add to the annotation descriptions. regex : str The regex to match the annotation descriptions. Returns ------- annotations : mne.Annotations The annotations object with updated descriptions. """ from re import match for i, annot in enumerate(annotations): if match(regex, annot["description"]): annotations.description[i] = prefix + annotations.description[i] return annotations