@@ -128,13 +128,15 @@ def test_dss1(show=False):
128
128
atol = 1e-6 ) # use abs as DSS component might be flipped
129
129
130
130
131
- def test_dss_line ():
131
+ @pytest .mark .parametrize ('nkeep' , [None , 2 ])
132
+ def test_dss_line (nkeep ):
132
133
"""Test line noise removal."""
133
134
sr = 200
135
+ fline = 20
134
136
nsamples = 10000
135
137
nchans = 10
136
138
x = np .random .randn (nsamples , nchans )
137
- artifact = np .sin (np .arange (nsamples ) / sr * 2 * np .pi * 20 )[:, None ]
139
+ artifact = np .sin (np .arange (nsamples ) / sr * 2 * np .pi * fline )[:, None ]
138
140
artifact [artifact < 0 ] = 0
139
141
artifact = artifact ** 3
140
142
s = x + 10 * artifact
@@ -155,24 +157,25 @@ def _plot(x):
155
157
plt .show ()
156
158
157
159
# 2D case, n_outputs == 1
158
- out , _ = dss .dss_line (s , 20 , sr )
160
+ out , _ = dss .dss_line (s , fline , sr , nkeep = nkeep )
159
161
_plot (out )
160
162
161
163
# Test n_outputs > 1
162
- out , _ = dss .dss_line (s , 20 , sr , nremove = 2 )
164
+ out , _ = dss .dss_line (s , fline , sr , nkeep = nkeep , nremove = 2 )
163
165
# _plot(out)
164
166
165
167
# Test n_trials > 1
166
168
x = np .random .randn (nsamples , nchans , 4 )
167
- artifact = np .sin (np .arange (nsamples ) / sr * 2 * np .pi * 20 )[:, None , None ]
169
+ artifact = np .sin (
170
+ np .arange (nsamples ) / sr * 2 * np .pi * fline )[:, None , None ]
168
171
artifact [artifact < 0 ] = 0
169
172
artifact = artifact ** 3
170
173
s = x + 10 * artifact
171
- out , _ = dss .dss_line (s , 20 , sr , nremove = 1 )
174
+ out , _ = dss .dss_line (s , fline , sr , nremove = 1 )
172
175
173
176
174
177
if __name__ == '__main__' :
175
178
pytest .main ([__file__ ])
176
179
# create_data(SNR=5, show=True)
177
180
# test_dss1(True)
178
- # test_dss_line()
181
+ # test_dss_line(None )
0 commit comments