From 5fbd36f1e4130cae11cb5ffa329ce0ce4a3f178e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Fri, 11 Oct 2024 17:34:40 +0200 Subject: [PATCH 1/5] Reject last sample of annotations - Keep bad annotations that end at the start sample (annotations that are potentialy one sample long): end = start - Do not discard bad annotations that are exactly one sample long: onset = end - Also discard last sample of bad annotation: end - start + 1 --- mne/io/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mne/io/base.py b/mne/io/base.py index 79cbbe192ba..b1c473a6ae3 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -912,6 +912,8 @@ def get_data( Whether to reject by annotation. If None (default), no rejection is done. If 'omit', segments annotated with description starting with 'bad' are omitted. If 'NaN', the bad samples are filled with NaNs. + Note that the last sample of each annotation will also be omitted + or replaced with NaN. return_times : bool Whether to return times as well. Defaults to False. %(units)s @@ -983,7 +985,7 @@ def get_data( "reject_by_annotation", reject_by_annotation.lower(), ["omit", "nan"] ) onsets, ends = _annotations_starts_stops(self, ["BAD"]) - keep = (onsets < stop) & (ends > start) + keep = (onsets < stop) & (ends >= start) onsets = np.maximum(onsets[keep], start) ends = np.minimum(ends[keep], stop) if len(onsets) == 0: @@ -996,9 +998,9 @@ def get_data( n_samples = stop - start # total number of samples used = np.ones(n_samples, bool) for onset, end in zip(onsets, ends): - if onset >= end: + if onset > end: continue - used[onset - start : end - start] = False + used[onset - start : end - start + 1] = False used = np.concatenate([[False], used, [False]]) starts = np.where(~used[:-1] & used[1:])[0] + start stops = np.where(used[:-1] & ~used[1:])[0] + start From 47a010460e776145d91bc5712225eaafffc95263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Fri, 11 Oct 2024 17:35:39 +0200 Subject: [PATCH 2/5] Fix and add tests - Fix now broken tests - Add tests for annotations that are one sample long --- mne/io/tests/test_raw.py | 40 +++++++++++++++++++++++++++-------- mne/tests/test_annotations.py | 10 ++++----- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 5f4556d6d8e..e7ac3369c23 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -706,29 +706,51 @@ def test_meas_date_orig_time(): def test_get_data_reject(): """Test if reject_by_annotation is working correctly.""" - fs = 256 + fs = 100 ch_names = ["C3", "Cz", "C4"] info = create_info(ch_names, sfreq=fs) - raw = RawArray(np.zeros((len(ch_names), 10 * fs)), info) + n_times = 10 * fs + raw = RawArray(np.zeros((len(ch_names), n_times)), info) raw.set_annotations(Annotations(onset=[2, 4], duration=[3, 2], description="bad")) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", verbose=True) msg = ( - "Omitting 1024 of 2560 (40.00%) samples, retaining 1536" - + " (60.00%) samples." + "Omitting 401 of 1000 (40.10%) samples, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 1536) + assert data.shape == (len(ch_names), 599) with catch_logging() as log: data = raw.get_data(reject_by_annotation="nan", verbose=True) msg = ( - "Setting 1024 of 2560 (40.00%) samples to NaN, retaining 1536" - + " (60.00%) samples." + "Setting 401 of 1000 (40.10%) samples to NaN, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 2560) # shape doesn't change - assert np.isnan(data).sum() == 3072 # but NaNs are introduced instead + assert data.shape == (len(ch_names), n_times) # shape doesn't change + assert np.isnan(data).sum() == 1203 # but NaNs are introduced instead + + # Test that 1-sample annotations at start and end of recording are handled + raw.set_annotations(Annotations(onset=[0], duration=[0], description="bad")) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[0], description="bad")) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + + # Test that 1-sample annotations are handled correctly, when they occur + # because of cropping + raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[1/fs], description="bad")) + with catch_logging() as log: + data = raw.get_data(reject_by_annotation="omit", start=1, verbose=True) + msg = ( + "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" + + " samples." + ) + assert log.getvalue().strip() == msg + print(data.shape) + assert data.shape == (len(ch_names), n_times - 2) def test_5839(): diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..167163ff8b7 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -432,7 +432,7 @@ def test_raw_reject(first_samp): return_times=True, # 1-112 s ) bad_times = np.concatenate( - [np.arange(200, 400), np.arange(10000, 10800), np.arange(10500, 11000)] + [np.arange(200, 401), np.arange(10000, 10801), np.arange(10500, 11001)] ) expected_times = np.setdiff1d(np.arange(100, 11200), bad_times) / sfreq assert_allclose(times, expected_times) @@ -450,7 +450,7 @@ def test_raw_reject(first_samp): t_stop = 18.0 assert raw.times[-1] > t_stop n_stop = int(round(t_stop * raw.info["sfreq"])) - n_drop = int(round(4 * raw.info["sfreq"])) + n_drop = int(round(4 * raw.info["sfreq"]) + 2) assert len(raw.times) >= n_stop data, times = raw.get_data(range(10), 0, n_stop, "omit", True) assert data.shape == (10, n_stop - n_drop) @@ -558,8 +558,8 @@ def test_annotation_filtering(first_samp): raw = raws[0].copy() raw.set_annotations(Annotations([0.0], [0.5], ["BAD_ACQ_SKIP"])) my_data, times = raw.get_data(reject_by_annotation="omit", return_times=True) - assert_allclose(times, raw.times[500:]) - assert my_data.shape == (1, 500) + assert_allclose(times, raw.times[501:]) + assert my_data.shape == (1, 499) raw_filt = raw.copy().filter(skip_by_annotation="bad_acq_skip", **kwargs_stop) expected = data.copy() expected[:, 500:] = 0 @@ -586,7 +586,7 @@ def test_annotation_omit(first_samp): expected = raw[0][0] assert_allclose(raw.get_data(reject_by_annotation=None), expected) # nan - expected[0, 500:1500] = np.nan + expected[0, 500:1501] = np.nan assert_allclose(raw.get_data(reject_by_annotation="nan"), expected) got = np.concatenate( [ From 16dfd8a2b8ba28eedfaac43b65b724ed7955a208 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:48:30 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/io/tests/test_raw.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index e7ac3369c23..3217f3500fc 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -735,19 +735,20 @@ def test_get_data_reject(): raw.set_annotations(Annotations(onset=[0], duration=[0], description="bad")) data = raw.get_data(reject_by_annotation="omit", verbose=True) assert data.shape == (len(ch_names), n_times - 1) - raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[0], description="bad")) + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[0], description="bad") + ) data = raw.get_data(reject_by_annotation="omit", verbose=True) assert data.shape == (len(ch_names), n_times - 1) # Test that 1-sample annotations are handled correctly, when they occur # because of cropping - raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[1/fs], description="bad")) + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[1 / fs], description="bad") + ) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", start=1, verbose=True) - msg = ( - "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" - + " samples." - ) + msg = "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" + " samples." assert log.getvalue().strip() == msg print(data.shape) assert data.shape == (len(ch_names), n_times - 2) From 9651cd045baf33601aa1ce835149fd000c614a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Thu, 7 Nov 2024 11:03:15 +0100 Subject: [PATCH 4/5] Add include_last as argument to _annotations_starts_stops - Add include_last as keyword argument to _annotations_starts_stops - Adapt corresponding function call in `BaseRaw.get_data()`, and correspondingly remove +1 indexing of used samples - Update documentation of _annotations_starts_stops - Add tests for _annotations_starts_stops - Add towncrier entry for PR --- mne/annotations.py | 13 +++++++--- mne/io/base.py | 4 +-- mne/tests/test_annotations.py | 49 +++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/mne/annotations.py b/mne/annotations.py index 629ee7b20cb..13d0d57dd4c 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -1037,10 +1037,15 @@ def _sync_onset(raw, onset, inverse=False): return annot_start -def _annotations_starts_stops(raw, kinds, name="skip_by_annotation", invert=False): - """Get starts and stops from given kinds. +def _annotations_starts_stops( + raw, kinds, name="skip_by_annotation", invert=False, include_last=False +): + """Get starts and stops (i.e., onsets and ends) from given kinds. - onsets and ends are inclusive. + If `include_last` is False (default), ends will indicate the last sample of + the annotations. If `include_last` is True, ends will indicate the last samples +1. + This is useful when for example ``_annotations_starts_stops`` is used to index + entire bad segments in order to reject these. """ _validate_type(kinds, (str, list, tuple), name) if isinstance(kinds, str): @@ -1063,6 +1068,8 @@ def _annotations_starts_stops(raw, kinds, name="skip_by_annotation", invert=Fals ends = onsets + raw.annotations.duration[idxs] onsets = raw.time_as_index(onsets, use_rounding=True) ends = raw.time_as_index(ends, use_rounding=True) + if include_last: + ends += 1 assert (onsets <= ends).all() # all durations >= 0 if invert: # We need to eliminate overlaps here, otherwise wacky things happen, diff --git a/mne/io/base.py b/mne/io/base.py index b1c473a6ae3..3b781c289cb 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -984,7 +984,7 @@ def get_data( _check_option( "reject_by_annotation", reject_by_annotation.lower(), ["omit", "nan"] ) - onsets, ends = _annotations_starts_stops(self, ["BAD"]) + onsets, ends = _annotations_starts_stops(self, ["BAD"], include_last=True) keep = (onsets < stop) & (ends >= start) onsets = np.maximum(onsets[keep], start) ends = np.minimum(ends[keep], stop) @@ -1000,7 +1000,7 @@ def get_data( for onset, end in zip(onsets, ends): if onset > end: continue - used[onset - start : end - start + 1] = False + used[onset - start : end - start] = False used = np.concatenate([[False], used, [False]]) starts = np.where(~used[:-1] & used[1:])[0] + start stops = np.where(used[:-1] & ~used[1:])[0] + start diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 167163ff8b7..4c483713555 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -29,6 +29,7 @@ read_annotations, ) from mne.annotations import ( + _annotations_starts_stops, _handle_meas_date, _read_annotations_txt_parse_header, _sync_onset, @@ -1826,3 +1827,51 @@ def test_append_splits_boundary(tmp_path, split_size): assert len(raw.annotations) == 2 assert raw.annotations.description[0] == "BAD boundary" assert_allclose(raw.annotations.onset, [onset] * 2) + + +def test_annotations_starts_stops(): + """Test _annotations_starts_stops function.""" + sfreq = 10 + info = mne.create_info(1, sfreq, "eeg") + raw = mne.io.RawArray(np.random.RandomState(0).randn(1, 30 * sfreq), info) + annotations = Annotations( + [0, 10, 20], [5, 5, 5], ["BAD", "BAD", "GOOD"], raw.info["meas_date"] + ) + raw.set_annotations(annotations) + + # Test with single kind + onsets, ends = _annotations_starts_stops(raw, "BAD") + assert_array_equal(onsets, raw.time_as_index([0, 10])) + assert_array_equal(ends, raw.time_as_index([5, 15])) + + # Test with multiple kinds + onsets, ends = _annotations_starts_stops(raw, ["BAD", "GOOD"]) + assert_array_equal(onsets, raw.time_as_index([0, 10, 20])) + assert_array_equal(ends, raw.time_as_index([5, 15, 25])) + + # Test with invert=True + onsets, ends = _annotations_starts_stops(raw, "BAD", invert=True) + assert_array_equal(onsets, raw.time_as_index([5, 15])) + assert_array_equal(ends, raw.time_as_index([10, 30])) + + # Test with include_last=True + onsets, ends = _annotations_starts_stops(raw, ["BAD", "GOOD"], include_last=True) + assert_array_equal(onsets, raw.time_as_index([0, 10, 20])) + assert_array_equal(ends, raw.time_as_index([5.1, 15.1, 25.1])) + + # Test with include_last=True and invert=True + onsets, ends = _annotations_starts_stops(raw, "BAD", invert=True, include_last=True) + assert_array_equal(onsets, raw.time_as_index([5.1, 15.1])) + assert_array_equal(ends, raw.time_as_index([10, 30])) + + # Test with no annotations + raw.set_annotations(Annotations([], [], [], raw.info["meas_date"])) + onsets, ends = _annotations_starts_stops(raw, "BAD") + assert_array_equal(onsets, np.array([], int)) + assert_array_equal(ends, np.array([], int)) + + # Test with no matching kinds + raw.set_annotations(Annotations([0], [5], ["GOOD"], raw.info["meas_date"])) + onsets, ends = _annotations_starts_stops(raw, "BAD") + assert_array_equal(onsets, np.array([], int)) + assert_array_equal(ends, np.array([], int)) From 7fb81fc626bb6d4971305031b0997c2223948cbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Thu, 7 Nov 2024 11:03:53 +0100 Subject: [PATCH 5/5] Create 12895.bugfix.rst --- doc/changes/devel/12895.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/devel/12895.bugfix.rst diff --git a/doc/changes/devel/12895.bugfix.rst b/doc/changes/devel/12895.bugfix.rst new file mode 100644 index 00000000000..4d31d656427 --- /dev/null +++ b/doc/changes/devel/12895.bugfix.rst @@ -0,0 +1 @@ +Fix bug in :func:`mne.io.BaseRaw.get_data`, where the ``reject_by_annotation`` parameter would not result in rejection of the last sample of the annotation, by `Richard Koehler`_.