Skip to content

Commit

Permalink
Fix to_data_frame bug with tapers
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 11, 2024
1 parent de39d25 commit 80126a7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
50 changes: 38 additions & 12 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,12 +1292,15 @@ def test_to_data_frame():
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
n_picks = len(ch_names)
ch_types = ["eeg"] * n_picks
n_tapers = 2
n_freqs = 5
n_times = 6
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
times = np.arange(6)
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
times = np.arange(n_times)
srate = 1000.0
freqs = np.arange(5)
freqs = np.arange(n_freqs)
tapers = np.arange(n_tapers)
weights = np.ones((n_tapers, n_freqs))
events = np.zeros((n_epos, 3), dtype=int)
events[:, 0] = np.arange(n_epos)
events[:, 2] = np.arange(5, 5 + n_epos)
Expand All @@ -1310,6 +1313,7 @@ def test_to_data_frame():
freqs=freqs,
events=events,
event_id=event_id,
weights=weights,
)
# test index checking
with pytest.raises(ValueError, match="options. Valid index options are"):
Expand All @@ -1321,32 +1325,51 @@ def test_to_data_frame():
# test wide format
df_wide = tfr.to_data_frame()
assert all(np.isin(tfr.ch_names, df_wide.columns))
assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns))
assert all(
np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns)
)
# test long format
df_long = tfr.to_data_frame(long_format=True)
expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value")
expected = (
"condition",
"epoch",
"freq",
"time",
"channel",
"ch_type",
"value",
"taper",
)
assert set(expected) == set(df_long.columns)
assert set(tfr.ch_names) == set(df_long["channel"])
assert len(df_long) == tfr.data.size
# test long format w/ index
df_long = tfr.to_data_frame(long_format=True, index=["freq"])
del df_wide, df_long
# test whether data is in correct shape
df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"])
df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"])
data = tfr.data
assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze())
# compare arbitrary observation:
assert (
df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0]
== data[1, 3, 1, 2]
df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0]
== data[1, 3, 1, 1, 2]
)

# Check also for AverageTFR:
# (remove taper dimension before averaging)
state = tfr.__getstate__()
state["data"] = state["data"][:, :, 0]
state["dims"] = ("epoch", "channel", "freq", "time")
state["weights"] = None
tfr = EpochsTFR(inst=state)
tfr = tfr.average()
with pytest.raises(ValueError, match="options. Valid index options are"):
tfr.to_data_frame(index=["epoch", "condition"])
with pytest.raises(ValueError, match='"epoch" is not a valid option'):
tfr.to_data_frame(index="epoch")
with pytest.raises(ValueError, match='"taper" is not a valid option'):
tfr.to_data_frame(index="taper")
with pytest.raises(TypeError, match="index must be `None` or a string "):
tfr.to_data_frame(index=np.arange(400))
# test wide format
Expand Down Expand Up @@ -1382,11 +1405,13 @@ def test_to_data_frame_index(index):
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
n_picks = len(ch_names)
ch_types = ["eeg"] * n_picks
n_tapers = 2
n_freqs = 5
n_times = 6
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
times = np.arange(6)
freqs = np.arange(5)
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.ones((n_tapers, n_freqs))
events = np.zeros((n_epos, 3), dtype=int)
events[:, 0] = np.arange(n_epos)
events[:, 2] = np.arange(5, 8)
Expand All @@ -1399,14 +1424,15 @@ def test_to_data_frame_index(index):
freqs=freqs,
events=events,
event_id=event_id,
weights=weights,
)
df = tfr.to_data_frame(picks=[0, 2, 3], index=index)
# test index order/hierarchy preservation
if not isinstance(index, list):
index = [index]
assert list(df.index.names) == index
# test that non-indexed data were present as columns
non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index))
non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index))
if len(non_index):
assert all(np.isin(non_index, df.columns))

Expand Down
53 changes: 35 additions & 18 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,7 @@ def get_data(
tmax=None,
return_times=False,
return_freqs=False,
return_tapers=False,
):
"""Get time-frequency data in NumPy array format.
Expand All @@ -1852,6 +1853,10 @@ def get_data(
return_freqs : bool
Whether to return the frequency bin values for the requested
frequency range. Default is ``False``.
return_tapers : bool
Whether to return the taper numbers. Default is ``False``.
.. versionadded:: 1.X.0
Returns
-------
Expand All @@ -1863,6 +1868,9 @@ def get_data(
freqs : array
The frequency values for the requested data range. Only returned if
``return_freqs`` is ``True``.
tapers : array | None
The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be
``None`` if a taper dimension is not present in the data.
Notes
-----
Expand Down Expand Up @@ -1900,7 +1908,13 @@ def get_data(
if return_freqs:
freqs = self._freqs[fmin_idx:fmax_idx]
out.append(freqs)
if not return_times and not return_freqs:
if return_tapers:
if "taper" in self._dims:
tapers = np.arange(self.shape[self._dims.index("taper")])
else:
tapers = None
out.append(tapers)
if not return_times and not return_freqs and not return_tapers:
return out[0]
return tuple(out)

Expand Down Expand Up @@ -2676,21 +2690,21 @@ def to_data_frame(
):
"""Export data in tabular structure as a pandas DataFrame.
Channels are converted to columns in the DataFrame. By default,
additional columns ``'time'``, ``'freq'``, ``'epoch'``, and
``'condition'`` (epoch event description) are added, unless ``index``
is not ``None`` (in which case the columns specified in ``index`` will
be used to form the DataFrame's index instead). ``'epoch'``, and
``'condition'`` are not supported for ``AverageTFR``.
Channels are converted to columns in the DataFrame. By default, additional
columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'``
(epoch event description) are added, unless ``index`` is not ``None`` (in which
case the columns specified in ``index`` will be used to form the DataFrame's
index instead). ``'epoch'`` and ``'condition'`` are not supported for
``AverageTFR``. ``'taper'`` is only supported when a taper dimension is
present, such as for complex or phase multitaper data.
Parameters
----------
%(picks_all)s
%(index_df_epo)s
Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and
``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'``
for ``AverageTFR``.
Defaults to ``None``.
Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``,
and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and
``'taper'`` for ``AverageTFR``. Defaults to ``None``.
%(long_format_df_epo)s
%(time_format_df)s
Expand All @@ -2710,12 +2724,16 @@ def to_data_frame(
valid_index_args = ["time", "freq"]
if from_epo:
valid_index_args.extend(["epoch", "condition"])
if unagg_mt:
valid_index_args.append("taper")
valid_time_formats = ["ms", "timedelta"]
index = _check_pandas_index_arguments(index, valid_index_args)
time_format = _check_time_format(time_format, valid_time_formats)
# get data
picks = _picks_to_idx(self.info, picks, "all", exclude=())
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
data, times, freqs, tapers = self.get_data(
picks, return_times=True, return_freqs=True, return_tapers=True
)
ch_axis = self._dims.index("channel")
if not from_epo:
data = data[np.newaxis] # add singleton "epochs" axis
Expand All @@ -2731,7 +2749,7 @@ def to_data_frame(
default_index = list()
times = _convert_times(times, time_format, self.info["meas_date"])
times = np.tile(times, n_epochs * n_freqs * n_tapers)
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers)
mindex.append(("time", times))
mindex.append(("freq", freqs))
if from_epo:
Expand All @@ -2744,12 +2762,11 @@ def to_data_frame(
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
)
default_index.extend(["condition", "epoch"])
default_index.extend(["freq", "time"])
if unagg_mt:
name = "taper"
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
mindex.append((name, taper_nums))
default_index.append(name)
tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times)
mindex.append(("taper", tapers))
default_index.append("taper")
default_index.extend(["freq", "time"])
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
# build DataFrame
df = _build_data_frame(
Expand Down

0 comments on commit 80126a7

Please sign in to comment.