Skip to content

Commit

Permalink
[ENH] Add option to store and return TFR taper weights (mne-tools#12910)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCloy <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored and qian-chu committed Jan 20, 2025
1 parent 2b02f46 commit e395a47
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 141 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12910.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_.
10 changes: 10 additions & 0 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def tfr_array_multitaper(
output="complex",
n_jobs=None,
*,
return_weights=False,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
Expand Down Expand Up @@ -504,6 +505,11 @@ def tfr_array_multitaper(
coherence across trials.
%(n_jobs)s
The parallelization is implemented across channels.
return_weights : bool, default False
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.
.. versionadded:: 1.10.0
%(verbose)s
Returns
Expand All @@ -520,6 +526,9 @@ def tfr_array_multitaper(
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.
See Also
--------
Expand Down Expand Up @@ -550,6 +559,7 @@ def tfr_array_multitaper(
use_fft=use_fft,
decim=decim,
output=output,
return_weights=return_weights,
n_jobs=n_jobs,
verbose=verbose,
)
221 changes: 195 additions & 26 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,21 @@ def test_tfr_morlet():
def test_dpsswavelet():
"""Test DPSS tapers."""
freqs = np.arange(5, 25, 3)
Ws = _make_dpss(
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
Ws, weights = _make_dpss(
1000,
freqs=freqs,
n_cycles=freqs / 2.0,
time_bandwidth=4.0,
zero_mean=True,
return_weights=True,
)

assert len(Ws) == 3 # 3 tapers expected
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)

# Check that zero mean is true
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5

assert len(Ws[0]) == len(freqs) # As many wavelets as asked for


@pytest.mark.slowtest
def test_tfr_multitaper():
Expand Down Expand Up @@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
with tfr.info._unlock():
tfr.info["meas_date"] = want
assert tfr_loaded == tfr
# test with taper dimension and weights
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
state = tfr.__getstate__()
state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim
state["weights"] = weights # add weights
state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims
tfr = EpochsTFR(inst=state)
tfr.save(fname, overwrite=True)
tfr_loaded = read_tfrs(fname)
assert tfr_loaded == tfr
# test overwrite
with pytest.raises(OSError, match="Destination file exists."):
tfr.save(fname, overwrite=False)
Expand Down Expand Up @@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked):
AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace)


def test_epochstfr_init_errors(epochs_tfr):
"""Test __init__ for EpochsTFR."""
state = epochs_tfr.__getstate__()
with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0]))
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
def test_tfr_init_errors(inst, request, average_tfr):
"""Test __init__ for {Raw,Epochs,Average}TFR."""
# Load data
inst = _get_inst(inst, request, average_tfr=average_tfr)
state = inst.__getstate__()
# Prepare for TFRArray object instantiation
inst_name = inst.__class__.__name__
class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR)
ndims_mapping = dict(
RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D")
)
TFR = class_mapping[inst_name]
allowed_ndims = ndims_mapping[inst_name]
# Check errors caught
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=inst.data[..., 0]))
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1))))
with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1]))
TFR(inst=state | dict(data=inst.data[..., :-1, :, :]))
with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"):
EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1]))
TFR(inst=state | dict(times=inst.times[:-1]))
with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"):
EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1]))
TFR(inst=state | dict(freqs=inst.freqs[:-1]))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -830,6 +859,25 @@ def test_plot():
plt.close("all")


@pytest.mark.parametrize("output", ("complex", "phase"))
def test_plot_multitaper_complex_phase(output):
"""Test TFR plotting of data with a taper dimension."""
# Create example data with a taper dimension
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3)
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times)
if output == "complex":
data = data + np.random.rand(*data.shape) * 1j # add imaginary data
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.random.rand(n_tapers, n_freqs)
info = mne.create_info(n_chans, 1000.0, "eeg")
tfr = AverageTFRArray(
info=info, data=data, times=times, freqs=freqs, weights=weights
)
# Check that plotting works
tfr.plot()


@pytest.mark.parametrize(
"timefreqs,title,combine",
(
Expand Down Expand Up @@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR():
):
power.average(method=np.mean)

# Check it doesn't run for taper spectra
tapered = epochs.compute_tfr(
method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex"
)
with pytest.raises(
NotImplementedError, match=r"Averaging multitaper tapers .* is not supported."
):
tapered.average()


def test_averaging_freqsandtimes_epochsTFR():
"""Test that EpochsTFR averaging freqs methods work."""
Expand Down Expand Up @@ -1258,12 +1315,15 @@ def test_to_data_frame():
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
n_picks = len(ch_names)
ch_types = ["eeg"] * n_picks
n_tapers = 2
n_freqs = 5
n_times = 6
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
times = np.arange(6)
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
times = np.arange(n_times)
srate = 1000.0
freqs = np.arange(5)
freqs = np.arange(n_freqs)
tapers = np.arange(n_tapers)
weights = np.ones((n_tapers, n_freqs))
events = np.zeros((n_epos, 3), dtype=int)
events[:, 0] = np.arange(n_epos)
events[:, 2] = np.arange(5, 5 + n_epos)
Expand All @@ -1276,6 +1336,7 @@ def test_to_data_frame():
freqs=freqs,
events=events,
event_id=event_id,
weights=weights,
)
# test index checking
with pytest.raises(ValueError, match="options. Valid index options are"):
Expand All @@ -1287,32 +1348,51 @@ def test_to_data_frame():
# test wide format
df_wide = tfr.to_data_frame()
assert all(np.isin(tfr.ch_names, df_wide.columns))
assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns))
assert all(
np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns)
)
# test long format
df_long = tfr.to_data_frame(long_format=True)
expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value")
expected = (
"condition",
"epoch",
"freq",
"time",
"channel",
"ch_type",
"value",
"taper",
)
assert set(expected) == set(df_long.columns)
assert set(tfr.ch_names) == set(df_long["channel"])
assert len(df_long) == tfr.data.size
# test long format w/ index
df_long = tfr.to_data_frame(long_format=True, index=["freq"])
del df_wide, df_long
# test whether data is in correct shape
df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"])
df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"])
data = tfr.data
assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze())
# compare arbitrary observation:
assert (
df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0]
== data[1, 3, 1, 2]
df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0]
== data[1, 3, 1, 1, 2]
)

# Check also for AverageTFR:
# (remove taper dimension before averaging)
state = tfr.__getstate__()
state["data"] = state["data"][:, :, 0]
state["dims"] = ("epoch", "channel", "freq", "time")
state["weights"] = None
tfr = EpochsTFR(inst=state)
tfr = tfr.average()
with pytest.raises(ValueError, match="options. Valid index options are"):
tfr.to_data_frame(index=["epoch", "condition"])
with pytest.raises(ValueError, match='"epoch" is not a valid option'):
tfr.to_data_frame(index="epoch")
with pytest.raises(ValueError, match='"taper" is not a valid option'):
tfr.to_data_frame(index="taper")
with pytest.raises(TypeError, match="index must be `None` or a string "):
tfr.to_data_frame(index=np.arange(400))
# test wide format
Expand Down Expand Up @@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index):
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
n_picks = len(ch_names)
ch_types = ["eeg"] * n_picks
n_tapers = 2
n_freqs = 5
n_times = 6
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
times = np.arange(6)
freqs = np.arange(5)
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.ones((n_tapers, n_freqs))
events = np.zeros((n_epos, 3), dtype=int)
events[:, 0] = np.arange(n_epos)
events[:, 2] = np.arange(5, 8)
Expand All @@ -1365,14 +1447,15 @@ def test_to_data_frame_index(index):
freqs=freqs,
events=events,
event_id=event_id,
weights=weights,
)
df = tfr.to_data_frame(picks=[0, 2, 3], index=index)
# test index order/hierarchy preservation
if not isinstance(index, list):
index = [index]
assert list(df.index.names) == index
# test that non-indexed data were present as columns
non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index))
non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index))
if len(non_index):
assert all(np.isin(non_index, df.columns))

Expand Down Expand Up @@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output):
"""Test Epochs.compute_tfr(output="complex"/"phase")."""
tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output)
assert len(tfr.shape) == 5
assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time
assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match


@pytest.mark.parametrize("copy", (False, True))
Expand All @@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
assert avgs[0].comment == str(epochs_tfr.events[0, -1])


@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(obj_type):
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
# Create example data with a taper dimension
n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6)
data_shape = (n_chans, n_tapers, n_freqs, n_times)
if obj_type == "epochs":
data_shape = (n_epochs,) + data_shape
data = np.random.rand(*data_shape)
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.random.rand(n_tapers, n_freqs)
info = mne.create_info(n_chans, 1000.0, "eeg")
# Prepare for TFRArray object instantiation
defaults = dict(info=info, data=data, times=times, freqs=freqs)
class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray)
TFRArray = class_mapping[obj_type]
# Check TFRArray instantiation runs with good data
TFRArray(**defaults, weights=weights)
# Check taper dimension but no weights caught
with pytest.raises(
ValueError, match="Taper dimension in data, but no weights found."
):
TFRArray(**defaults)
# Check mismatching n_taper in weights caught
with pytest.raises(
ValueError, match=r"Taper axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:-1])
# Check mismatching n_freq in weights caught
with pytest.raises(
ValueError, match=r"Frequency axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:, :-1])


def test_tfr_proj(epochs):
"""Test `compute_tfr(proj=True)`."""
epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True)
Expand Down Expand Up @@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
assert re.match(
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title()
)


@pytest.mark.parametrize("output", ("complex", "phase"))
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked):
"""Test plot_joint/topo/topomap() for data with a taper dimension."""
# Compute TFR with taper dimension
tfr = evoked.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output
)
# Check that plotting works
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed
tfr.plot_topo()
tfr.plot_topomap()


def test_combine_tfr_error_catch(average_tfr):
"""Test combine_tfr() catches errors."""
# check unrecognised weights string caught
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
combine_tfr([average_tfr, average_tfr], weights="foo")
# check bad weights size caught
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"):
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1])
# check different channel names caught
state = average_tfr.__getstate__()
new_info = average_tfr.info.copy()
average_tfr_bad = AverageTFR(
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"}))
)
with pytest.raises(AssertionError, match=".* do not contain the same channels"):
combine_tfr([average_tfr, average_tfr_bad])
# check different times caught
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1))
with pytest.raises(
AssertionError, match=".* do not contain the same time instants"
):
combine_tfr([average_tfr, average_tfr_bad])
# check taper dim caught
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1)
state["weights"] = weights
state["dims"] = ("channel", "taper", "freq", "time")
average_tfr_taper = AverageTFR(inst=state)
with pytest.raises(
NotImplementedError,
match="Aggregating multitaper tapers across TFR datasets is not supported.",
):
combine_tfr([average_tfr_taper, average_tfr_taper])
Loading

0 comments on commit e395a47

Please sign in to comment.