-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GSOC] Add EpochsTFR
support to spectral connectivity functions
#232
Conversation
Have now added support for |
a86f9b9
to
7013e19
Compare
This reverts commit 6fd0861.
Resetting tests to main mne branch now that mne-tools/mne-python#12842 is merged |
Thanks for the detailed review @drammock! Until those wider API questions are addressed I'll double check the unit tests and try to find where the deviation in Fourier/Welch pipelines. |
Co-authored-by: Daniel McCloy <[email protected]>
for more information, see https://pre-commit.ci
Conda test keeps freezing on the micromamba setup step, seems similar to this: mamba-org/setup-micromamba#225 |
Yeah feel free! Sounds like there might be a workaround via pinning in some of those issue comments |
Haven't forgotten about this, just need to sort mne-tools/mne-python#12910 before I can finalise the changes and tests here. |
mne-tools/mne-python#12910 is now merged, so I can fix the outstanding issues here. |
def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y): | ||
"""Compute time-frequency CSD from tapered spectra. | ||
|
||
Parameters | ||
---------- | ||
x_mt : array, shape (..., n_tapers, n_freqs, n_times) | ||
The tapered time-frequency spectra for signals x. | ||
y_mt : array, shape (..., n_tapers, n_freqs, n_times) | ||
The tapered time-frequency spectra for signals y. | ||
weights_x : array, shape (n_tapers, n_freqs) | ||
Weights to use for combining the tapered spectra of x_mt. | ||
weights_y : array, shape (n_tapers, n_freqs) | ||
Weights to use for combining the tapered spectra of y_mt. | ||
|
||
Returns | ||
------- | ||
csd : array, shape (..., n_freqs, n_times) | ||
The CSD between x and y. | ||
""" | ||
# expand weights dims to match x_mt and y_mt | ||
weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1)) | ||
weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1)) | ||
# compute CSD | ||
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3) | ||
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt( | ||
(weights_y * weights_y.conj()).real.sum(axis=-3) | ||
) | ||
csd *= 2 / denom | ||
return csd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't a _csd_from_mt
equivalent for TFR data in MNE. Would people be open to me adding this to the main MNE package in a small PR?
Otherwise it can just remain here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think even if you put it there you'd need some sort of try/except
based on MNE version with a fallback to this code (in this codebase) anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. So we should keep this here for now.
if not is_tfr_con: # normal spectra (multitaper or Fourier) | ||
this_psd = _psd_from_mt(x_t, weights) | ||
else: # TFR spectra (multitaper) | ||
this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Realising now that the way I implemented _tfr_from_mt
in mne-tools/mne-python#12910 works fine there, but isn't flexible to an epochs dimension (in contrast _psd_from_mt
is).
Would people be open to a small PR in the main package to fix this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure!
I've addressed the previous issues. There should now be full support for |
Given that the multitaper TFR requires the unreleased MNE 1.10, is it worth holding off on merging (once it's checked and ready to go) until after MNE-Conn 0.8 is released? On that note, since the last major part of the GSoC project in #223 was merged, there's a decent amount for a new release: https://mne.tools/mne-connectivity/dev/whats_new.html Opened an issue here -> #276 |
You could, but even then you'll need to add a version guard to the mne-connectivity that checks the MNE version. You have them in tests but you also need them in the user-facing function. You can use something like |
... although I think the
in the user-facing code |
Currently I'm checking compatibility based on the object attrs. Do you think an explicit version check is cleaner? mne-connectivity/mne_connectivity/spectral/epochs.py Lines 1225 to 1233 in 73ce0a6
|
Is the issue just about which MNE version is currently installed? Or is it to do with which version of MNE the data were saved to disk with? If the latter plays a role, then attr-based checking seems right to me (though I'd tweak the error message to explicitly reference the possibility of loaded data that were saved by an older version of MNE). On the other hand if it's just about the currently installed version of MNE, then checking that (rather than object attrs) makes for easier-to-maintain code IMO (as the code more directly reflects the source of the problem). |
Moreso the latter, whatever MNE version was used to compute the coefficients.
Yeah good call, I'll add this. |
@larsoner this one LGTM, I'll let you merge if happy |
if not is_tfr_con: # normal spectra (multitaper or Fourier) | ||
this_psd = _psd_from_mt(x_t, weights) | ||
else: # TFR spectra (multitaper) | ||
this_psd = np.array([_tfr_from_mt(epo_x, weights) for epo_x in x_t]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure!
def _tfr_csd_from_mt(x_mt, y_mt, weights_x, weights_y): | ||
"""Compute time-frequency CSD from tapered spectra. | ||
|
||
Parameters | ||
---------- | ||
x_mt : array, shape (..., n_tapers, n_freqs, n_times) | ||
The tapered time-frequency spectra for signals x. | ||
y_mt : array, shape (..., n_tapers, n_freqs, n_times) | ||
The tapered time-frequency spectra for signals y. | ||
weights_x : array, shape (n_tapers, n_freqs) | ||
Weights to use for combining the tapered spectra of x_mt. | ||
weights_y : array, shape (n_tapers, n_freqs) | ||
Weights to use for combining the tapered spectra of y_mt. | ||
|
||
Returns | ||
------- | ||
csd : array, shape (..., n_freqs, n_times) | ||
The CSD between x and y. | ||
""" | ||
# expand weights dims to match x_mt and y_mt | ||
weights_x = np.expand_dims(weights_x, axis=(*np.arange(x_mt.ndim - 3), -1)) | ||
weights_y = np.expand_dims(weights_y, axis=(*np.arange(y_mt.ndim - 3), -1)) | ||
# compute CSD | ||
csd = np.sum(weights_x * x_mt * (weights_y * y_mt).conj(), axis=-3) | ||
denom = np.sqrt((weights_x * weights_x.conj()).real.sum(axis=-3)) * np.sqrt( | ||
(weights_y * weights_y.conj()).real.sum(axis=-3) | ||
) | ||
csd *= 2 / denom | ||
return csd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think even if you put it there you'd need some sort of try/except
based on MNE version with a fallback to this code (in this codebase) anyway.
Thanks @tsbinns ! |
WIP follow up of #225 to finalise GSoC work.
So far adds support for coefficients from
Epochs.compute_tfr(method="morlet", output="complex")
tospectral_connectivity_epochs()
, equivalent to thecwt_morlet
mode. Still to be done is adding support forEpochsTFR
objects inspectral_connectivity_time()
.In addition to the Morlet approach,
spectral_connectivity_time()
supports the time-freq. multitaper mode. #126 could also be addressed by adding support for this tospectral_connectivity_epochs()
, in the form ofEpochsTFR
objects.However, when trying this I discovered a bug that prevents the time-freq. multitaper mode being computed from epoched data (mne-tools/mne-python#12831), but that seems like an easy fix.
Will continue to work on this next week.