From 2922224d4a5fd6aac9d9ee19e969e2aa75129d0b Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 09:34:03 +0000 Subject: [PATCH 1/7] added iclabel code from PR, user input which IC labels to keep, user input whether to do eog or ecg detection on ICs --- mne_bids_pipeline/_config.py | 32 +++ mne_bids_pipeline/_config_import.py | 15 + .../steps/preprocessing/_06a1_fit_ica.py | 23 +- .../preprocessing/_06a2_find_ica_artifacts.py | 258 +++++++++++------- .../tests/configs/config_ERP_CORE.py | 8 +- pyproject.toml | 2 + 6 files changed, 232 insertions(+), 106 deletions(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 5a66cd2c5..e7980d681 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1370,12 +1370,22 @@ `1` or `None` to not perform any decimation. """ +ica_use_ecg_detection: bool = True +""" +Whether to use the MNE ECG detection on the ICA components. +""" + ica_ecg_threshold: float = 0.1 """ The cross-trial phase statistics (CTPS) threshold parameter used for detecting ECG-related ICs. """ +ica_use_eog_detection: bool = True +""" +Whether to use the MNE EOG detection on the ICA components. +""" + ica_eog_threshold: float = 3.0 """ The threshold to use during automated EOG classification. Lower values mean @@ -1383,6 +1393,28 @@ false-alarm rate increases dramatically. """ + + +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + h_freq = 100 + ``` +""" + +icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = [] # TODO: Find out how to use ["brain", "other"] as default +""" +Which independent components (ICs) to keep based on the labels given by ICLabel. +Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". +""" + # ### Amplitude-based artifact rejection # # ???+ info "Good Practice / Advice" diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index e6cf83c5a..947f6874b 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -337,6 +337,20 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"but got shape {destination.shape}" ) +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # MNE-ICALabel + if config.ica_use_icalabel: + if config.ica_l_freq != 1.0 or config.h_freq != 100.0: + raise ValueError( + f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}" + ) + + if config.eeg_reference != "average": + raise ValueError( + f'When using MNE-ICALabel, you must set eeg_reference="average", but ' + f"got: eeg_reference={config.eeg_reference}" + ) def _default_factory(key, val): # convert a default to a default factory if needed, having an explicit @@ -346,6 +360,7 @@ def _default_factory(key, val): {"custom": (8, 24.0, 40)}, # decoding_csp_freqs ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate + #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include ] for typ in (dict, list): if isinstance(val, typ): diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py index 598d2e308..39b4e59a2 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -74,6 +74,14 @@ def run_ica( """Run ICA.""" import matplotlib.pyplot as plt + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data flltered between 1 and 100 Hz. + assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.h_freq == 100.0 + assert cfg.eeg_reference == "average" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] out_files = dict() bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) @@ -158,7 +166,18 @@ def run_ica( # Set an EEG reference if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False + if cfg.ica_use_icalabel: + assert cfg.eeg_reference == "average" + projection = False # Avg. ref. needs to be applied for MNE-ICALabel + elif cfg.eeg_reference == "average": + projection = True + else: + projection = False + + if not projection: + msg = "Applying average reference to EEG epochs used for ICA fitting." + logger.info(**gen_log_kwargs(message=msg)) + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) ar_reject_log = ar_n_interpolate_ = None @@ -331,10 +350,12 @@ def get_config( ica_max_iterations=config.ica_max_iterations, ica_decim=config.ica_decim, ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, l_freq=config.l_freq, + h_freq=config.h_freq, epochs_decim=config.epochs_decim, raw_resample_sfreq=config.raw_resample_sfreq, event_repeated=config.event_repeated, diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 43f88032a..2acb3ace1 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -15,6 +15,8 @@ import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs from mne_bids import BIDSPath +from mne_icalabel import label_components +import mne_icalabel from ..._config_utils import ( _bids_kwargs, @@ -151,113 +153,144 @@ def find_ica_artifacts( # ECG component detection epochs_ecg = None ecg_ics, ecg_scores = [], [] - for ri, raw_fname in enumerate(raw_fnames): - # Have the channels needed to make ECG epochs - raw = mne.io.read_raw(raw_fname, preload=False) - # ECG epochs - if not ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = ( - "No ECG or magnetometer channels are present, cannot " - "automate artifact detection for ECG." + if cfg.ica_use_ecg_detection: + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG." + ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, ) - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / len(raw_fnames) - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - these_ecg_epochs = create_ecg_epochs( - raw, - baseline=(None, -0.2), - tmin=-0.5, - tmax=0.5, - ) - del raw # Free memory - if len(these_ecg_epochs): - if epochs.reject is not None: - these_ecg_epochs.drop_bad(reject=epochs.reject) + del raw # Free memory if len(these_ecg_epochs): - if epochs_ecg is None: - epochs_ecg = these_ecg_epochs - else: - epochs_ecg = mne.concatenate_epochs( - [epochs_ecg, these_ecg_epochs], on_mismatch="warn" - ) - del these_ecg_epochs - else: # did not break so had usable channels - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) # EOG component detection epochs_eog = None eog_ics = eog_scores = [] - for ri, raw_fname in enumerate(raw_fnames): - raw = mne.io.read_raw_fif(raw_fname, preload=True) - if cfg.eog_channels: - ch_names = cfg.eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - eog_picks = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[pick] for pick in eog_picks] - if not ch_names: - msg = "No EOG channel is present, cannot automate IC detection for EOG." - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - these_eog_epochs = create_eog_epochs( - raw, - ch_name=ch_names, - baseline=(None, -0.2), - ) - if len(these_eog_epochs): - if epochs.reject is not None: - these_eog_epochs.drop_bad(reject=epochs.reject) + if cfg.ica_use_eog_detection: + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) if len(these_eog_epochs): - if epochs_eog is None: - epochs_eog = these_eog_epochs - else: - epochs_eog = mne.concatenate_epochs( - [epochs_eog, these_eog_epochs], on_mismatch="warn" - ) - else: # did not break - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + + # Run MNE-ICALabel if requested. + if cfg.ica_use_icalabel: + icalabel_ics = [] + icalabel_labels = [] + icalabel_prob = [] + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = mne_icalabel.label_components(inst=epochs, ica=ica, method="iclabel") + for idx, (label,prob) in enumerate(zip(label_results["labels"],label_results["y_pred_proba"])): + #icalabel_include = ["brain", "other"] + print(label) + print(prob) + + if label not in cfg.icalabel_include: + icalabel_ics.append(idx) + icalabel_labels.append(label) + icalabel_prob.append(prob) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs." ) + logger.info(**gen_log_kwargs(message=msg)) + else: + icalabel_ics = [] + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) # Save updated ICA to disk. # We also store the automatically identified ECG- and EOG-related ICs. msg = "Saving ICA solution and detected artifacts to disk." logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) ica.save(out_files["ica"], overwrite=True) # Create TSV. @@ -271,15 +304,28 @@ def find_ica_artifacts( ) ) - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + if cfg.ica_use_icalabel: + assert len(icalabel_ics) == len(icalabel_labels) + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = f"Auto-detected {label} (MNE-ICALabel)" + if cfg.ica_use_ecg_detection: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected ECG artifact (MNE)" + if cfg.ica_use_eog_detection: + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected EOG artifact (MNE)" tsv_data.to_csv(out_files_components, sep="\t", index=False) @@ -288,6 +334,8 @@ def find_ica_artifacts( ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() eog_evoked = None if epochs_eog is None else epochs_eog.average() + + ecg_scores = None if len(ecg_scores) == 0 else ecg_scores eog_scores = None if len(eog_scores) == 0 else eog_scores @@ -345,8 +393,12 @@ def get_config( task_is_rest=config.task_is_rest, ica_l_freq=config.ica_l_freq, ica_reject=config.ica_reject, + ica_use_eog_detection = config.ica_use_eog_detection, ica_eog_threshold=config.ica_eog_threshold, + ica_use_ecg_detection = config.ica_use_ecg_detection, ica_ecg_threshold=config.ica_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, + icalabel_include = config.icalabel_include, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index e536450fa..650aa7395 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -75,11 +75,15 @@ spatial_filter = None reject = "autoreject_local" autoreject_n_interpolate = [2, 4] -elif task == "N170": # test autoreject local before ICA +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_l_freq = 1 + h_freq = 100 ica_reject = "autoreject_local" reject = "autoreject_global" - autoreject_n_interpolate = [2, 4] + autoreject_n_interpolate = [12] # only for testing! else: spatial_filter = "ica" ica_reject = dict(eeg=350e-6, eog=500e-6) diff --git a/pyproject.toml b/pyproject.toml index 5ae807b9e..b9ceda85b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "autoreject", "mne[hdf5] >=1.7", "mne-bids[full]", + "mne-icalabel", + "onnxruntime", # for mne-icalabel "filelock", ] dynamic = ["version"] From 333880790424096c04a673193a5c8b5d7e1a5b6b Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 11:26:55 +0000 Subject: [PATCH 2/7] use brain and other as default for icalabel_include --- mne_bids_pipeline/_config.py | 2 +- mne_bids_pipeline/_config_import.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index e7980d681..aba6fd2ce 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1409,7 +1409,7 @@ ``` """ -icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = [] # TODO: Find out how to use ["brain", "other"] as default +icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = ["brain","other"] """ Which independent components (ICs) to keep based on the labels given by ICLabel. Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 947f6874b..e947e58f0 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -361,6 +361,7 @@ def _default_factory(key, val): ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include + ["brain","other"] ] for typ in (dict, list): if isinstance(val, typ): From 4bf3976ac34735ea0de8e70ed10584c8dc9d04e1 Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 14:46:34 +0000 Subject: [PATCH 3/7] added rudimentary plot of excluded components + labels + probabilities --- .../preprocessing/_06a2_find_ica_artifacts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 2acb3ace1..d2ac04e22 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -14,9 +14,11 @@ import numpy as np import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne.viz import plot_ica_components from mne_bids import BIDSPath from mne_icalabel import label_components import mne_icalabel +import matplotlib.pyplot as plt from ..._config_utils import ( _bids_kwargs, @@ -373,6 +375,22 @@ def find_ica_artifacts( tags=("ica",), # the default but be explicit ) + # Add a plot for each excluded IC together with the given label and the probability + # TODO: Improve this plot e.g. combine all figures in one plot + for ic, label, prob in zip(icalabel_ics, icalabel_labels, icalabel_prob): + excluded_IC_figure = plot_ica_components( + ica=ica, + picks=ic, + ) + excluded_IC_figure.axes[0].text(0, -0.15, f"Label: {label} \n Probability: {prob:.3f}", ha="center", fontsize=8, bbox={"facecolor":"orange", "alpha":0.5, "pad":5}) + + report.add_figure( + fig=excluded_IC_figure, + title = f'ICA{ic:03}', + replace=True, + ) + plt.close(excluded_IC_figure) + msg = 'Carefully review the extracted ICs and mark components "bad" in:' logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) From 8ca913273e27451d557e505029c864b44c1564cd Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 09:34:03 +0000 Subject: [PATCH 4/7] added iclabel code from PR, user input which IC labels to keep, user input whether to do eog or ecg detection on ICs --- mne_bids_pipeline/_config.py | 32 +++ mne_bids_pipeline/_config_import.py | 15 + .../steps/preprocessing/_06a1_fit_ica.py | 23 +- .../preprocessing/_06a2_find_ica_artifacts.py | 256 +++++++++++------- .../tests/configs/config_ERP_CORE.py | 8 +- pyproject.toml | 2 + 6 files changed, 230 insertions(+), 106 deletions(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index c3abfc57f..8f1cbe242 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1376,12 +1376,22 @@ `1` or `None` to not perform any decimation. """ +ica_use_ecg_detection: bool = True +""" +Whether to use the MNE ECG detection on the ICA components. +""" + ica_ecg_threshold: float = 0.1 """ The cross-trial phase statistics (CTPS) threshold parameter used for detecting ECG-related ICs. """ +ica_use_eog_detection: bool = True +""" +Whether to use the MNE EOG detection on the ICA components. +""" + ica_eog_threshold: float = 3.0 """ The threshold to use during automated EOG classification. Lower values mean @@ -1389,6 +1399,28 @@ false-alarm rate increases dramatically. """ + + +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + h_freq = 100 + ``` +""" + +icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = [] # TODO: Find out how to use ["brain", "other"] as default +""" +Which independent components (ICs) to keep based on the labels given by ICLabel. +Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". +""" + # ### Amplitude-based artifact rejection # # ???+ info "Good Practice / Advice" diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 1616b9ff4..e12f7d20d 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -341,6 +341,20 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"but got shape {destination.shape}" ) +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # MNE-ICALabel + if config.ica_use_icalabel: + if config.ica_l_freq != 1.0 or config.h_freq != 100.0: + raise ValueError( + f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}" + ) + + if config.eeg_reference != "average": + raise ValueError( + f'When using MNE-ICALabel, you must set eeg_reference="average", but ' + f"got: eeg_reference={config.eeg_reference}" + ) def _default_factory(key: str, val: Any) -> Any: # convert a default to a default factory if needed, having an explicit @@ -350,6 +364,7 @@ def _default_factory(key: str, val: Any) -> Any: {"custom": (8, 24.0, 40)}, # decoding_csp_freqs ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate + #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include ] def default_factory() -> Any: diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py index b2da2a49b..39c7078ce 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -79,6 +79,14 @@ def run_ica( """Run ICA.""" import matplotlib.pyplot as plt + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data flltered between 1 and 100 Hz. + assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.h_freq == 100.0 + assert cfg.eeg_reference == "average" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] out_files = dict() bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) @@ -164,7 +172,18 @@ def run_ica( # Set an EEG reference if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False + if cfg.ica_use_icalabel: + assert cfg.eeg_reference == "average" + projection = False # Avg. ref. needs to be applied for MNE-ICALabel + elif cfg.eeg_reference == "average": + projection = True + else: + projection = False + + if not projection: + msg = "Applying average reference to EEG epochs used for ICA fitting." + logger.info(**gen_log_kwargs(message=msg)) + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) ar_reject_log = ar_n_interpolate_ = None @@ -338,10 +357,12 @@ def get_config( ica_max_iterations=config.ica_max_iterations, ica_decim=config.ica_decim, ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, l_freq=config.l_freq, + h_freq=config.h_freq, epochs_decim=config.epochs_decim, raw_resample_sfreq=config.raw_resample_sfreq, event_repeated=config.event_repeated, diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 90db43c06..6fcb84ea3 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -15,6 +15,8 @@ import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs from mne_bids import BIDSPath +from mne_icalabel import label_components +import mne_icalabel from mne_bids_pipeline._config_utils import ( _bids_kwargs, @@ -157,114 +159,145 @@ def find_ica_artifacts( epochs_ecg = None ecg_ics: list[int] = [] ecg_scores: FloatArrayT = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - # Have the channels needed to make ECG epochs - raw = mne.io.read_raw(raw_fname, preload=False) - # ECG epochs - if not ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = ( - "No ECG or magnetometer channels are present, cannot " - "automate artifact detection for ECG." + if cfg.ica_use_ecg_detection: + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG." + ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, ) - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / len(raw_fnames) - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - these_ecg_epochs = create_ecg_epochs( - raw, - baseline=(None, -0.2), - tmin=-0.5, - tmax=0.5, - ) - del raw # Free memory - if len(these_ecg_epochs): - if epochs.reject is not None: - these_ecg_epochs.drop_bad(reject=epochs.reject) + del raw # Free memory if len(these_ecg_epochs): - if epochs_ecg is None: - epochs_ecg = these_ecg_epochs - else: - epochs_ecg = mne.concatenate_epochs( - [epochs_ecg, these_ecg_epochs], on_mismatch="warn" - ) - del these_ecg_epochs - else: # did not break so had usable channels - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) # EOG component detection epochs_eog = None eog_ics: list[int] = [] eog_scores = np.zeros(0) - for ri, raw_fname in enumerate(raw_fnames): - raw = mne.io.read_raw_fif(raw_fname, preload=True) - if cfg.eog_channels: - ch_names = cfg.eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - eog_picks = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[pick] for pick in eog_picks] - if not ch_names: - msg = "No EOG channel is present, cannot automate IC detection for EOG." - logger.info(**gen_log_kwargs(message=msg)) - break - elif ri == 0: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - these_eog_epochs = create_eog_epochs( - raw, - ch_name=ch_names, - baseline=(None, -0.2), - ) - if len(these_eog_epochs): - if epochs.reject is not None: - these_eog_epochs.drop_bad(reject=epochs.reject) + if cfg.ica_use_eog_detection: + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) if len(these_eog_epochs): - if epochs_eog is None: - epochs_eog = these_eog_epochs - else: - epochs_eog = mne.concatenate_epochs( - [epochs_eog, these_eog_epochs], on_mismatch="warn" - ) - else: # did not break - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + + # Run MNE-ICALabel if requested. + if cfg.ica_use_icalabel: + icalabel_ics = [] + icalabel_labels = [] + icalabel_prob = [] + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = mne_icalabel.label_components(inst=epochs, ica=ica, method="iclabel") + for idx, (label,prob) in enumerate(zip(label_results["labels"],label_results["y_pred_proba"])): + #icalabel_include = ["brain", "other"] + print(label) + print(prob) + + if label not in cfg.icalabel_include: + icalabel_ics.append(idx) + icalabel_labels.append(label) + icalabel_prob.append(prob) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs." ) + logger.info(**gen_log_kwargs(message=msg)) + else: + icalabel_ics = [] + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) # Save updated ICA to disk. # We also store the automatically identified ECG- and EOG-related ICs. msg = "Saving ICA solution and detected artifacts to disk." logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) ica.save(out_files["ica"], overwrite=True) # Create TSV. @@ -278,15 +311,28 @@ def find_ica_artifacts( ) ) - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + if cfg.ica_use_icalabel: + assert len(icalabel_ics) == len(icalabel_labels) + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = f"Auto-detected {label} (MNE-ICALabel)" + if cfg.ica_use_ecg_detection: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected ECG artifact (MNE)" + if cfg.ica_use_eog_detection: + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected EOG artifact (MNE)" tsv_data.to_csv(out_files_components, sep="\t", index=False) @@ -350,8 +396,12 @@ def get_config( task_is_rest=config.task_is_rest, ica_l_freq=config.ica_l_freq, ica_reject=config.ica_reject, + ica_use_eog_detection = config.ica_use_eog_detection, ica_eog_threshold=config.ica_eog_threshold, + ica_use_ecg_detection = config.ica_use_ecg_detection, ica_ecg_threshold=config.ica_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, + icalabel_include = config.icalabel_include, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index e536450fa..650aa7395 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -75,11 +75,15 @@ spatial_filter = None reject = "autoreject_local" autoreject_n_interpolate = [2, 4] -elif task == "N170": # test autoreject local before ICA +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_l_freq = 1 + h_freq = 100 ica_reject = "autoreject_local" reject = "autoreject_global" - autoreject_n_interpolate = [2, 4] + autoreject_n_interpolate = [12] # only for testing! else: spatial_filter = "ica" ica_reject = dict(eeg=350e-6, eog=500e-6) diff --git a/pyproject.toml b/pyproject.toml index d8643b495..041dfc69a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "autoreject", "mne[hdf5] >=1.7", "mne-bids[full]", + "mne-icalabel", + "onnxruntime", # for mne-icalabel "filelock", ] dynamic = ["version"] From 5b4d14e6e5878d4e6411707a1845dfdd822b0054 Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 11:26:55 +0000 Subject: [PATCH 5/7] use brain and other as default for icalabel_include --- mne_bids_pipeline/_config.py | 2 +- mne_bids_pipeline/_config_import.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 8f1cbe242..d86a982be 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1415,7 +1415,7 @@ ``` """ -icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = [] # TODO: Find out how to use ["brain", "other"] as default +icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = ["brain","other"] """ Which independent components (ICs) to keep based on the labels given by ICLabel. Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index e12f7d20d..2cbf3d194 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -365,6 +365,7 @@ def _default_factory(key: str, val: Any) -> Any: ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include + ["brain","other"] ] def default_factory() -> Any: From fd304e11baeecccbcf722a2f2aebda1d1cbc6dc8 Mon Sep 17 00:00:00 2001 From: jschepers Date: Wed, 26 Jun 2024 14:46:34 +0000 Subject: [PATCH 6/7] added rudimentary plot of excluded components + labels + probabilities --- .../preprocessing/_06a2_find_ica_artifacts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 6fcb84ea3..93b191a87 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -14,9 +14,11 @@ import numpy as np import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne.viz import plot_ica_components from mne_bids import BIDSPath from mne_icalabel import label_components import mne_icalabel +import matplotlib.pyplot as plt from mne_bids_pipeline._config_utils import ( _bids_kwargs, @@ -376,6 +378,22 @@ def find_ica_artifacts( tags=("ica",), # the default but be explicit ) + # Add a plot for each excluded IC together with the given label and the probability + # TODO: Improve this plot e.g. combine all figures in one plot + for ic, label, prob in zip(icalabel_ics, icalabel_labels, icalabel_prob): + excluded_IC_figure = plot_ica_components( + ica=ica, + picks=ic, + ) + excluded_IC_figure.axes[0].text(0, -0.15, f"Label: {label} \n Probability: {prob:.3f}", ha="center", fontsize=8, bbox={"facecolor":"orange", "alpha":0.5, "pad":5}) + + report.add_figure( + fig=excluded_IC_figure, + title = f'ICA{ic:03}', + replace=True, + ) + plt.close(excluded_IC_figure) + msg = 'Carefully review the extracted ICs and mark components "bad" in:' logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) From 7d2a7295ecd2af39150d1554bd780c7acd40393e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:39:30 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_bids_pipeline/_config.py | 16 +++++- mne_bids_pipeline/_config_import.py | 7 +-- .../preprocessing/_06a2_find_ica_artifacts.py | 53 +++++++++++-------- .../tests/configs/config_ERP_CORE.py | 2 +- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index d86a982be..ed9a51e6d 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1400,7 +1400,6 @@ """ - # From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 ica_use_icalabel: bool = False """ @@ -1415,7 +1414,20 @@ ``` """ -icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = ["brain","other"] +icalabel_include: Annotated[ + Sequence[ + Literal[ + "brain", + "muscle artifact", + "eye blink", + "heart beat", + "line noise", + "channel noise", + "other", + ] + ], + Len(1, 7), +] = ["brain", "other"] """ Which independent components (ICs) to keep based on the labels given by ICLabel. Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 2cbf3d194..3c2206277 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -341,7 +341,7 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"but got shape {destination.shape}" ) -# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 # MNE-ICALabel if config.ica_use_icalabel: if config.ica_l_freq != 1.0 or config.h_freq != 100.0: @@ -356,6 +356,7 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None f"got: eeg_reference={config.eeg_reference}" ) + def _default_factory(key: str, val: Any) -> Any: # convert a default to a default factory if needed, having an explicit # allowlist of non-empty ones @@ -364,8 +365,8 @@ def _default_factory(key: str, val: Any) -> Any: {"custom": (8, 24.0, 40)}, # decoding_csp_freqs ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate - #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include - ["brain","other"] + # ["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include + ["brain", "other"], ] def default_factory() -> Any: diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 93b191a87..05d984c54 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -10,15 +10,14 @@ from types import SimpleNamespace from typing import Literal +import matplotlib.pyplot as plt import mne +import mne_icalabel import numpy as np import pandas as pd from mne.preprocessing import create_ecg_epochs, create_eog_epochs from mne.viz import plot_ica_components from mne_bids import BIDSPath -from mne_icalabel import label_components -import mne_icalabel -import matplotlib.pyplot as plt from mne_bids_pipeline._config_utils import ( _bids_kwargs, @@ -265,7 +264,6 @@ def find_ica_artifacts( subject=subject, session=session, ) - # Run MNE-ICALabel if requested. if cfg.ica_use_icalabel: @@ -274,10 +272,14 @@ def find_ica_artifacts( icalabel_prob = [] msg = "Performing automated artifact detection (MNE-ICALabel) …" logger.info(**gen_log_kwargs(message=msg)) - - label_results = mne_icalabel.label_components(inst=epochs, ica=ica, method="iclabel") - for idx, (label,prob) in enumerate(zip(label_results["labels"],label_results["y_pred_proba"])): - #icalabel_include = ["brain", "other"] + + label_results = mne_icalabel.label_components( + inst=epochs, ica=ica, method="iclabel" + ) + for idx, (label, prob) in enumerate( + zip(label_results["labels"], label_results["y_pred_proba"]) + ): + # icalabel_include = ["brain", "other"] print(label) print(prob) @@ -318,23 +320,23 @@ def find_ica_artifacts( for component, label in zip(icalabel_ics, icalabel_labels): row_idx = tsv_data["component"] == component tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[ - row_idx, "status_description" - ] = f"Auto-detected {label} (MNE-ICALabel)" + tsv_data.loc[row_idx, "status_description"] = ( + f"Auto-detected {label} (MNE-ICALabel)" + ) if cfg.ica_use_ecg_detection: for component in ecg_ics: row_idx = tsv_data["component"] == component tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[ - row_idx, "status_description" - ] = "Auto-detected ECG artifact (MNE)" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected ECG artifact (MNE)" + ) if cfg.ica_use_eog_detection: for component in eog_ics: row_idx = tsv_data["component"] == component tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[ - row_idx, "status_description" - ] = "Auto-detected EOG artifact (MNE)" + tsv_data.loc[row_idx, "status_description"] = ( + "Auto-detected EOG artifact (MNE)" + ) tsv_data.to_csv(out_files_components, sep="\t", index=False) @@ -385,11 +387,18 @@ def find_ica_artifacts( ica=ica, picks=ic, ) - excluded_IC_figure.axes[0].text(0, -0.15, f"Label: {label} \n Probability: {prob:.3f}", ha="center", fontsize=8, bbox={"facecolor":"orange", "alpha":0.5, "pad":5}) + excluded_IC_figure.axes[0].text( + 0, + -0.15, + f"Label: {label} \n Probability: {prob:.3f}", + ha="center", + fontsize=8, + bbox={"facecolor": "orange", "alpha": 0.5, "pad": 5}, + ) report.add_figure( fig=excluded_IC_figure, - title = f'ICA{ic:03}', + title=f"ICA{ic:03}", replace=True, ) plt.close(excluded_IC_figure) @@ -414,12 +423,12 @@ def get_config( task_is_rest=config.task_is_rest, ica_l_freq=config.ica_l_freq, ica_reject=config.ica_reject, - ica_use_eog_detection = config.ica_use_eog_detection, + ica_use_eog_detection=config.ica_use_eog_detection, ica_eog_threshold=config.ica_eog_threshold, - ica_use_ecg_detection = config.ica_use_ecg_detection, + ica_use_ecg_detection=config.ica_use_ecg_detection, ica_ecg_threshold=config.ica_ecg_threshold, ica_use_icalabel=config.ica_use_icalabel, - icalabel_include = config.icalabel_include, + icalabel_include=config.icalabel_include, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index 650aa7395..3915db661 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -83,7 +83,7 @@ h_freq = 100 ica_reject = "autoreject_local" reject = "autoreject_global" - autoreject_n_interpolate = [12] # only for testing! + autoreject_n_interpolate = [12] # only for testing! else: spatial_filter = "ica" ica_reject = dict(eeg=350e-6, eog=500e-6)