diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index c3abfc57f..ed9a51e6d 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1376,12 +1376,22 @@ `1` or `None` to not perform any decimation. """ +ica_use_ecg_detection: bool = True +""" +Whether to use the MNE ECG detection on the ICA components. +""" + ica_ecg_threshold: float = 0.1 """ The cross-trial phase statistics (CTPS) threshold parameter used for detecting ECG-related ICs. """ +ica_use_eog_detection: bool = True +""" +Whether to use the MNE EOG detection on the ICA components. +""" + ica_eog_threshold: float = 3.0 """ The threshold to use during automated EOG classification. Lower values mean @@ -1389,6 +1399,40 @@ false-alarm rate increases dramatically. """ + +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + h_freq = 100 + ``` +""" + +icalabel_include: Annotated[ + Sequence[ + Literal[ + "brain", + "muscle artifact", + "eye blink", + "heart beat", + "line noise", + "channel noise", + "other", + ] + ], + Len(1, 7), +] = ["brain", "other"] +""" +Which independent components (ICs) to keep based on the labels given by ICLabel. +Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". +""" + # ### Amplitude-based artifact rejection # # ???+ info "Good Practice / Advice" diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 1616b9ff4..3c2206277 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -341,6 +341,21 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"but got shape {destination.shape}" ) + # From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # MNE-ICALabel + if config.ica_use_icalabel: + if config.ica_l_freq != 1.0 or config.h_freq != 100.0: + raise ValueError( + f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}" + ) + + if config.eeg_reference != "average": + raise ValueError( + f'When using MNE-ICALabel, you must set eeg_reference="average", but ' + f"got: eeg_reference={config.eeg_reference}" + ) + def _default_factory(key: str, val: Any) -> Any: # convert a default to a default factory if needed, having an explicit @@ -350,6 +365,8 @@ def _default_factory(key: str, val: Any) -> Any: {"custom": (8, 24.0, 40)}, # decoding_csp_freqs ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate + # ["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include + ["brain", "other"], ] def default_factory() -> Any: diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py index b2da2a49b..39c7078ce 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -79,6 +79,14 @@ def run_ica( """Run ICA.""" import matplotlib.pyplot as plt + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data flltered between 1 and 100 Hz. + assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.h_freq == 100.0 + assert cfg.eeg_reference == "average" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] out_files = dict() bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) @@ -164,7 +172,18 @@ def run_ica( # Set an EEG reference if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False + if cfg.ica_use_icalabel: + assert cfg.eeg_reference == "average" + projection = False # Avg. ref. needs to be applied for MNE-ICALabel + elif cfg.eeg_reference == "average": + projection = True + else: + projection = False + + if not projection: + msg = "Applying average reference to EEG epochs used for ICA fitting." + logger.info(**gen_log_kwargs(message=msg)) + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) ar_reject_log = ar_n_interpolate_ = None @@ -338,10 +357,12 @@ def get_config( ica_max_iterations=config.ica_max_iterations, ica_decim=config.ica_decim, ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, l_freq=config.l_freq, + h_freq=config.h_freq, epochs_decim=config.epochs_decim, raw_resample_sfreq=config.raw_resample_sfreq, event_repeated=config.event_repeated, diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 90db43c06..05d984c54 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -10,10 +10,13 @@ from types import SimpleNamespace from typing import Literal +import matplotlib.pyplot as plt import mne +import mne_icalabel import numpy as np import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne.viz import plot_ica_components from mne_bids import BIDSPath from mne_bids_pipeline._config_utils import ( @@ -157,114 +160,148 @@ def find_ica_artifacts( epochs_ecg = None ecg_ics: list[int] = [] ecg_scores: FloatArrayT = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - # Have the channels needed to make ECG epochs - raw = mne.io.read_raw(raw_fname, preload=False) - # ECG epochs - if not ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = ( - "No ECG or magnetometer channels are present, cannot " - "automate artifact detection for ECG." + if cfg.ica_use_ecg_detection: + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG." + ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, ) - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / len(raw_fnames) - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - these_ecg_epochs = create_ecg_epochs( - raw, - baseline=(None, -0.2), - tmin=-0.5, - tmax=0.5, - ) - del raw # Free memory - if len(these_ecg_epochs): - if epochs.reject is not None: - these_ecg_epochs.drop_bad(reject=epochs.reject) + del raw # Free memory if len(these_ecg_epochs): - if epochs_ecg is None: - epochs_ecg = these_ecg_epochs - else: - epochs_ecg = mne.concatenate_epochs( - [epochs_ecg, these_ecg_epochs], on_mismatch="warn" - ) - del these_ecg_epochs - else: # did not break so had usable channels - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) # EOG component detection epochs_eog = None eog_ics: list[int] = [] eog_scores = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - raw = mne.io.read_raw_fif(raw_fname, preload=True) - if cfg.eog_channels: - ch_names = cfg.eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - eog_picks = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[pick] for pick in eog_picks] - if not ch_names: - msg = "No EOG channel is present, cannot automate IC detection for EOG." - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - these_eog_epochs = create_eog_epochs( - raw, - ch_name=ch_names, - baseline=(None, -0.2), - ) - if len(these_eog_epochs): - if epochs.reject is not None: - these_eog_epochs.drop_bad(reject=epochs.reject) + if cfg.ica_use_eog_detection: + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) if len(these_eog_epochs): - if epochs_eog is None: - epochs_eog = these_eog_epochs - else: - epochs_eog = mne.concatenate_epochs( - [epochs_eog, these_eog_epochs], on_mismatch="warn" - ) - else: # did not break - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + # Run MNE-ICALabel if requested. + if cfg.ica_use_icalabel: + icalabel_ics = [] + icalabel_labels = [] + icalabel_prob = [] + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = mne_icalabel.label_components( + inst=epochs, ica=ica, method="iclabel" + ) + for idx, (label, prob) in enumerate( + zip(label_results["labels"], label_results["y_pred_proba"]) + ): + # icalabel_include = ["brain", "other"] + print(label) + print(prob) + + if label not in cfg.icalabel_include: + icalabel_ics.append(idx) + icalabel_labels.append(label) + icalabel_prob.append(prob) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs." ) + logger.info(**gen_log_kwargs(message=msg)) + else: + icalabel_ics = [] + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) # Save updated ICA to disk. # We also store the automatically identified ECG- and EOG-related ICs. msg = "Saving ICA solution and detected artifacts to disk." logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) ica.save(out_files["ica"], overwrite=True) # Create TSV. @@ -278,15 +315,28 @@ def find_ica_artifacts( ) ) - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + if cfg.ica_use_icalabel: + assert len(icalabel_ics) == len(icalabel_labels) + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + f"Auto-detected {label} (MNE-ICALabel)" + ) + if cfg.ica_use_ecg_detection: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected ECG artifact (MNE)" + ) + if cfg.ica_use_eog_detection: + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected EOG artifact (MNE)" + ) tsv_data.to_csv(out_files_components, sep="\t", index=False) @@ -330,6 +380,29 @@ def find_ica_artifacts( tags=("ica",), # the default but be explicit ) + # Add a plot for each excluded IC together with the given label and the probability + # TODO: Improve this plot e.g. combine all figures in one plot + for ic, label, prob in zip(icalabel_ics, icalabel_labels, icalabel_prob): + excluded_IC_figure = plot_ica_components( + ica=ica, + picks=ic, + ) + excluded_IC_figure.axes[0].text( + 0, + -0.15, + f"Label: {label} \n Probability: {prob:.3f}", + ha="center", + fontsize=8, + bbox={"facecolor": "orange", "alpha": 0.5, "pad": 5}, + ) + + report.add_figure( + fig=excluded_IC_figure, + title=f"ICA{ic:03}", + replace=True, + ) + plt.close(excluded_IC_figure) + msg = 'Carefully review the extracted ICs and mark components "bad" in:' logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) @@ -350,8 +423,12 @@ def get_config( task_is_rest=config.task_is_rest, ica_l_freq=config.ica_l_freq, ica_reject=config.ica_reject, + ica_use_eog_detection=config.ica_use_eog_detection, ica_eog_threshold=config.ica_eog_threshold, + ica_use_ecg_detection=config.ica_use_ecg_detection, ica_ecg_threshold=config.ica_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, + icalabel_include=config.icalabel_include, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index e536450fa..3915db661 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -75,11 +75,15 @@ spatial_filter = None reject = "autoreject_local" autoreject_n_interpolate = [2, 4] -elif task == "N170": # test autoreject local before ICA +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_l_freq = 1 + h_freq = 100 ica_reject = "autoreject_local" reject = "autoreject_global" - autoreject_n_interpolate = [2, 4] + autoreject_n_interpolate = [12] # only for testing! else: spatial_filter = "ica" ica_reject = dict(eeg=350e-6, eog=500e-6) diff --git a/pyproject.toml b/pyproject.toml index d8643b495..041dfc69a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "autoreject", "mne[hdf5] >=1.7", "mne-bids[full]", + "mne-icalabel", + "onnxruntime", # for mne-icalabel "filelock", ] dynamic = ["version"]