-
Notifications
You must be signed in to change notification settings - Fork 38
Phase-slope index using spectral_connectivity_time instead of spectral_connectivity_epochs #210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 19 commits
2631a83
388ab8c
c38374a
b44da50
75a72b5
6cef3cb
dffda01
ef0a484
be0c3cd
ce68935
401a6a8
573211d
a96db00
ac393fe
b5ab63b
3c93427
5dad254
8820b6f
fd5738e
48e81a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,8 +7,12 @@ | |
| import numpy as np | ||
| from mne.utils import logger, verbose | ||
|
|
||
| from .base import SpectralConnectivity, SpectroTemporalConnectivity | ||
| from .spectral import spectral_connectivity_epochs | ||
| from .base import ( | ||
| EpochSpectralConnectivity, | ||
| SpectralConnectivity, | ||
| SpectroTemporalConnectivity, | ||
| ) | ||
| from .spectral import spectral_connectivity_epochs, spectral_connectivity_time | ||
| from .utils import fill_doc | ||
|
|
||
|
|
||
|
|
@@ -240,3 +244,192 @@ def phase_slope_index( | |
| ) | ||
|
|
||
| return conn | ||
|
|
||
|
|
||
| @verbose | ||
| @fill_doc | ||
| def phase_slope_index_time( | ||
| data, | ||
| freqs, | ||
| indices=None, | ||
| sfreq=2 * np.pi, | ||
| mode="cwt_morlet", | ||
| fmin=None, | ||
| fmax=None, | ||
| padding=0, | ||
| mt_bandwidth=None, | ||
| n_cycles=7, | ||
| n_jobs=1, | ||
| verbose=None, | ||
| ): | ||
| """Compute the Phase Slope Index (PSI) connectivity measure across time. | ||
|
|
||
| This function computes PSI over time from epoched data. The data may consist of a | ||
| single epoch. | ||
|
|
||
| The PSI is an effective connectivity measure, i.e., a measure which can give an | ||
| indication of the direction of the information flow (causality). For two time | ||
| series, one computes the PSI between the first and the second time series as | ||
| follows: :: | ||
|
|
||
| indices = (np.array([0]), np.array([1])) | ||
| psi = phase_slope_index(data, indices=indices, ...) | ||
|
|
||
| A positive value means that time series 0 is ahead of time series 1 and a negative | ||
| value means the opposite. | ||
|
|
||
| The PSI is computed from the coherency (see :func:`spectral_connectivity_time`), | ||
| details can be found in :footcite:`NolteEtAl2008`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : array-like, shape (n_epochs, n_signals, n_times) | Epochs | ||
| The data from which to compute connectivity. | ||
| freqs : array-like | ||
| Array of frequencies of interest for time-frequency decomposition. Only the | ||
| frequencies within the range specified by ``fmin`` and ``fmax`` are used. | ||
| indices : tuple of array | None | ||
| Two arrays with indices of connections for which to compute connectivity. If | ||
| `None`, all connections are computed. | ||
| sfreq : float | ||
| The sampling frequency. Required if data is not :class:`~mne.Epochs`. | ||
| mode : str | ||
| Time-frequency decomposition method. Can be either: 'multitaper' or | ||
| 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and | ||
| :func:`mne.time_frequency.tfr_array_morlet` for reference. | ||
| fmin : float | tuple of float | None | ||
| The lower frequency of interest. Multiple bands are defined using a tuple, e.g., | ||
| ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower bounds. If `None`, the | ||
| lowest frequency in ``freqs`` is used. | ||
| fmax : float | tuple of float | None | ||
| The upper frequency of interest. Multiple bands are defined using a tuple, e.g. | ||
| ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, the | ||
| highest frequency in ``freqs`` is used. | ||
| padding : float | ||
| Amount of time to consider as padding at the beginning and end of each epoch in | ||
| seconds. See Notes of :func:`spectral_connectivity_time` for more information. | ||
| mt_bandwidth : float | None | ||
| The bandwidth of the multitaper windowing function in Hz. Only used if | ||
| ``mode='multitaper'``. | ||
| n_cycles : float | array-like of float | ||
| Number of cycles. Fixed number or one per frequency. Only used if | ||
| ``mode='cwt_morlet'``. | ||
| n_jobs : int | ||
| Number of connections to compute in parallel. Memory mapping must be activated. | ||
| Please see the Notes section of :func:`spectral_connectivity_time` for details. | ||
| %(verbose)s | ||
|
|
||
| Returns | ||
| ------- | ||
| conn : instance of EpochSpectralConnectivity | ||
| Computed connectivity measure. An instance of | ||
| :class:`EpochSpectralConnectivity`. The shape of the connectivity dataset is | ||
| ``(n_epochs, n_cons, n_bands)``. When ``indices`` is `None`, | ||
| ``n_cons = n_signals ** 2``. When ``indices`` is specified, | ||
| ``n_cons = len(indices[0])``. | ||
|
|
||
| See Also | ||
| -------- | ||
| mne_connectivity.EpochSpectralConnectivity | ||
tsbinns marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| mne_connectivity.spectral_connectivity_time | ||
|
|
||
| References | ||
| ---------- | ||
| .. footbibliography:: | ||
| """ | ||
| logger.info("Estimating phase slope index (PSI) across time") | ||
|
|
||
| # estimate the coherency | ||
| cohy = spectral_connectivity_time( | ||
| data, | ||
| freqs=freqs, | ||
| method="cohy", | ||
| average=False, | ||
| indices=indices, | ||
| sfreq=sfreq, | ||
| fmin=fmin, | ||
| fmax=fmax, | ||
| fskip=0, | ||
| faverage=False, | ||
| sm_times=0, | ||
| sm_freqs=1, | ||
| sm_kernel="hanning", | ||
| padding=padding, | ||
|
Comment on lines
+354
to
+357
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wonder if we should expose more of these options to the user, e.g. why just |
||
| mode=mode, | ||
| mt_bandwidth=mt_bandwidth, | ||
| n_cycles=n_cycles, | ||
| decim=1, | ||
| n_jobs=n_jobs, | ||
| verbose=verbose, | ||
| ) | ||
|
|
||
| freqs_ = np.array(cohy.freqs) | ||
| names = cohy.names | ||
| n_tapers = cohy.attrs.get("n_tapers") | ||
| n_nodes = cohy.n_nodes | ||
| metadata = cohy.metadata | ||
| events = cohy.events | ||
| event_id = cohy.event_id | ||
|
|
||
| logger.info(f"Computing PSI from estimated Coherency: {cohy}") | ||
| # compute PSI in the requested bands | ||
| if fmin is None: | ||
| fmin = -np.inf | ||
| if fmax is None: | ||
| fmax = np.inf | ||
|
|
||
| bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel())) | ||
| n_bands = len(bands) | ||
|
|
||
| freq_dim = -1 | ||
|
|
||
| # allocate space for output | ||
| out_shape = list(cohy.shape) | ||
| out_shape[freq_dim] = n_bands | ||
| psi = np.zeros(out_shape, dtype=np.float64) | ||
|
|
||
| # allocate accumulator | ||
| acc_shape = copy.copy(out_shape) | ||
| acc_shape.pop(freq_dim) | ||
| acc = np.empty(acc_shape, dtype=np.complex128) | ||
|
|
||
| # create list for frequencies used and frequency bands | ||
| # of resulting connectivity data | ||
| freqs = list() | ||
| freq_bands = list() | ||
| idx_fi = [slice(None)] * len(out_shape) | ||
| idx_fj = [slice(None)] * len(out_shape) | ||
| for band_idx, band in enumerate(bands): | ||
| freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0] | ||
| freqs.append(freqs_[freq_idx]) | ||
| freq_bands.append(np.mean(freqs_[freq_idx])) | ||
|
|
||
| acc.fill(0.0) | ||
| for fi, fj in zip(freq_idx, freq_idx[1:]): | ||
| idx_fi[freq_dim] = fi | ||
| idx_fj[freq_dim] = fj | ||
| acc += ( | ||
| np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)] | ||
| ) | ||
|
|
||
| idx_fi[freq_dim] = band_idx | ||
| psi[tuple(idx_fi)] = np.imag(acc) | ||
| logger.info("[PSI Estimation Done]") | ||
|
|
||
| # create a connectivity container | ||
| conn = EpochSpectralConnectivity( | ||
| data=psi, | ||
| names=names, | ||
| freqs=freq_bands, | ||
| n_nodes=n_nodes, | ||
| method="phase-slope-index", | ||
| spec_method=mode, | ||
| indices=indices, | ||
| freqs_computed=freqs, | ||
| n_tapers=n_tapers, | ||
| metadata=metadata, | ||
| events=events, | ||
| event_id=event_id, | ||
| ) | ||
|
|
||
| return conn | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import numpy as np | ||
| from numpy.testing import assert_array_almost_equal | ||
|
|
||
| from mne_connectivity.effective import phase_slope_index | ||
| from mne_connectivity.effective import phase_slope_index, phase_slope_index_time | ||
|
|
||
|
|
||
| def test_psi(): | ||
|
|
@@ -39,3 +39,40 @@ def test_psi(): | |
|
|
||
| assert np.all(conn_cwt.get_data() > 0) | ||
| assert conn_cwt.shape[-1] == n_times | ||
|
|
||
|
|
||
| def test_psi_time(): | ||
| """Test Phase Slope Index (PSI) estimation across time.""" | ||
tsbinns marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| sfreq = 50.0 | ||
| n_signals = 3 | ||
| n_epochs = 10 | ||
| n_times = 500 | ||
| rng = np.random.RandomState(42) | ||
| data = rng.randn(n_epochs, n_signals, n_times) | ||
|
|
||
| # simulate time shifts | ||
| for i in range(n_epochs): | ||
| data[i, 1, 10:] = data[i, 0, :-10] # signal 0 is ahead | ||
| data[i, 2, :-10] = data[i, 0, 10:] # signal 2 is ahead | ||
|
|
||
| # conn = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4)) | ||
|
|
||
| # assert conn.get_data(output="dense")[1, 0, 0] < 0 | ||
| # assert conn.get_data(output="dense")[2, 0, 0] > 0 | ||
|
|
||
| # # only compute for a subset of the indices | ||
| indices = (np.array([0]), np.array([1])) | ||
| # conn_2 = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4), indices=indices) | ||
|
|
||
| # # the measure is symmetric (sign flip) | ||
| # assert_array_almost_equal( | ||
| # conn_2.get_data()[0, 0], -conn.get_data(output="dense")[1, 0, 0] | ||
| # ) | ||
|
|
||
| freqs = np.arange(5.0, 20, 0.5) | ||
| conn_cwt = phase_slope_index_time( | ||
| data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices | ||
| ) | ||
|
Comment on lines
+58
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Including some of the things commented here would be nice to better reflect the checks happening for the regular PSI function. So as well as checking that seed -> target connectivity is > 0, should also check that target -> seed connectivity is < 0 and they are identical but just sign-flipped. |
||
|
|
||
| assert np.all(conn_cwt.get_data() > 0) | ||
| assert conn_cwt.shape[0] == n_epochs | ||
Uh oh!
There was an error while loading. Please reload this page.