Skip to content

Commit 48eb2d4

Browse files
[FIX] dss_line() bug when using pca (#11)
1 parent 73fea33 commit 48eb2d4

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

meegkit/dss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def dss_line(x, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
194194
if nkeep is not None:
195195
xxx_cov = tscov(xxx)[0]
196196
V, _ = pca(xxx_cov, nkeep)
197-
xxxx = xxx * V
197+
xxxx = xxx @ V
198198
else:
199199
xxxx = xxx.copy()
200200

tests/test_dss.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,15 @@ def test_dss1(show=False):
128128
atol=1e-6) # use abs as DSS component might be flipped
129129

130130

131-
def test_dss_line():
131+
@pytest.mark.parametrize('nkeep', [None, 2])
132+
def test_dss_line(nkeep):
132133
"""Test line noise removal."""
133134
sr = 200
135+
fline = 20
134136
nsamples = 10000
135137
nchans = 10
136138
x = np.random.randn(nsamples, nchans)
137-
artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None]
139+
artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None]
138140
artifact[artifact < 0] = 0
139141
artifact = artifact ** 3
140142
s = x + 10 * artifact
@@ -155,24 +157,25 @@ def _plot(x):
155157
plt.show()
156158

157159
# 2D case, n_outputs == 1
158-
out, _ = dss.dss_line(s, 20, sr)
160+
out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep)
159161
_plot(out)
160162

161163
# Test n_outputs > 1
162-
out, _ = dss.dss_line(s, 20, sr, nremove=2)
164+
out, _ = dss.dss_line(s, fline, sr, nkeep=nkeep, nremove=2)
163165
# _plot(out)
164166

165167
# Test n_trials > 1
166168
x = np.random.randn(nsamples, nchans, 4)
167-
artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 20)[:, None, None]
169+
artifact = np.sin(
170+
np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None, None]
168171
artifact[artifact < 0] = 0
169172
artifact = artifact ** 3
170173
s = x + 10 * artifact
171-
out, _ = dss.dss_line(s, 20, sr, nremove=1)
174+
out, _ = dss.dss_line(s, fline, sr, nremove=1)
172175

173176

174177
if __name__ == '__main__':
175178
pytest.main([__file__])
176179
# create_data(SNR=5, show=True)
177180
# test_dss1(True)
178-
# test_dss_line()
181+
# test_dss_line(None)

0 commit comments

Comments
 (0)