Skip to content

Commit 776eb39

Browse files
committed
add example + add asr to doc
1 parent 0424f78 commit 776eb39

File tree

5 files changed

+71
-1
lines changed

5 files changed

+71
-1
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ install:
1616
- pip install sphinx-gallery
1717
- pip install numpydoc
1818
- pip install sphinx_bootstrap_theme
19+
- pip install git+https://github.com/pymanopt/pymanopt
1920
script:
2021
- mkdir docs
2122
- cd doc

doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'config.py']
5656

5757
# generate autosummary even if no references
58-
# autosummary_generate = True
58+
autosummary_generate = True
5959
autodoc_default_flags = ['members', 'undoc-members', 'show-inheritance', 'inherited-members']
6060
numpydoc_show_class_members = True
6161

doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Contents
1919
.. autosummary::
2020
:toctree: modules/
2121

22+
~meegkit.asr
2223
~meegkit.cca
2324
~meegkit.dss
2425
~meegkit.detrend

examples/example_asr.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
ASR example
3+
===========
4+
5+
Denoise data using Artifact Subspace Reconstruction.
6+
7+
Uses meegkit.ASR().
8+
"""
9+
import os
10+
import numpy as np
11+
import matplotlib.pyplot as plt
12+
13+
from meegkit.asr import ASR
14+
from meegkit.utils.asr import yulewalk_filter
15+
from meegkit.utils.matrix import sliding_window
16+
17+
# THIS_FOLDER = os.path.dirname(os.path.abspath(__file__))
18+
raw = np.load(os.path.join('..', 'tests', 'data', 'eeg_raw.npy'))
19+
sfreq = 250
20+
21+
###############################################################################
22+
# Calibration and processing
23+
# -----------------------------------------------------------------------------
24+
25+
# Train on a clean portion of data
26+
asr = ASR(method='euclid')
27+
train_idx = np.arange(0 * sfreq, 30 * sfreq, dtype=int)
28+
_, sample_mask = asr.fit(raw[:, train_idx])
29+
30+
# Apply filter using sliding (non-overlapping) windows
31+
X = sliding_window(raw, window=int(sfreq), step=int(sfreq))
32+
Y = np.zeros_like(X)
33+
for i in range(X.shape[1]):
34+
Y[:, i, :] = asr.transform(X[:, i, :])
35+
36+
raw = X.reshape(8, -1) # reshape to (n_chans, n_times)
37+
clean = Y.reshape(8, -1)
38+
39+
###############################################################################
40+
# Plot the results
41+
# -----------------------------------------------------------------------------
42+
#
43+
# Data was trained on a 40s window from 5s to 45s onwards (gray filled area).
44+
# The algorithm then removes portions of this data with high amplitude
45+
# artifacts before running the calibration (hatched area = good).
46+
47+
times = np.arange(raw.shape[-1]) / sfreq
48+
f, ax = plt.subplots(8, sharex=True, figsize=(8, 5))
49+
for i in range(8):
50+
ax[i].fill_between(train_idx / sfreq, 0, 1, color='grey', alpha=.3,
51+
transform=ax[i].get_xaxis_transform(),
52+
label='calibration window')
53+
ax[i].fill_between(train_idx / sfreq, 0, 1, where=sample_mask.flat,
54+
transform=ax[i].get_xaxis_transform(),
55+
facecolor='none', hatch='...', edgecolor='k',
56+
label='selected window')
57+
ax[i].plot(times, raw[i], lw=.5, label='before ASR')
58+
ax[i].plot(times, clean[i], label='after ASR', lw=.5)
59+
ax[i].set_ylim([-50, 50])
60+
ax[i].set_ylabel(f'ch{i}')
61+
ax[i].set_yticks([])
62+
ax[i].set_xlabel('Time (s)')
63+
ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), borderaxespad=0)
64+
plt.subplots_adjust(hspace=0, right=0.75)
65+
plt.suptitle('Before/after ASR')
66+
plt.show()

meegkit/asr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def fit(self, X, y=None, **kwargs):
191191
self.state_ = dict(M=M, T=T, R=None)
192192
self._fitted = True
193193

194+
return clean, sample_mask
195+
194196
def transform(self, X, y=None, **kwargs):
195197
"""Apply Artifact Subspace Reconstruction.
196198

0 commit comments

Comments
 (0)