Skip to content

Commit d01afe0

Browse files
authored
ENH: allow specifying ransac at class inst of NoisyChannels (#164)
* ENH: allow specifying ransac at class inst of NoisyChannels * update pre-commit * fix missing line
1 parent 5fa005b commit d01afe0

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@ repos:
1414
- id: check-docstring-first
1515

1616
- repo: https://github.com/astral-sh/ruff-pre-commit
17-
rev: v0.8.0
17+
rev: v0.12.3
1818
hooks:
1919
- id: ruff
20+
name: ruff check --fix
21+
files: ^(pyprep/|examples/|tests/)
2022
args: ["--fix"]
2123
- id: ruff-format
24+
name: ruff format
25+
files: ^(pyprep/|examples/|tests/)
2226

2327
- repo: https://github.com/pappasam/toml-sort
2428
rev: v0.24.2

pyprep/find_noisy_channels.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""finds bad channels."""
1+
"""Find bad channels."""
22

33
# Authors: The PyPREP developers
44
# SPDX-License-Identifier: MIT
@@ -44,6 +44,11 @@ class NoisyChannels:
4444
Whether or not PyPREP should strictly follow MATLAB PREP's internal
4545
math, ignoring any improvements made in PyPREP over the original code
4646
(see :ref:`matlab-diffs` for more details). Defaults to ``False``.
47+
ransac : bool
48+
Whether RANSAC should be used for bad channel detection, in addition
49+
to other methods. RANSAC can detect bad channels that other
50+
methods are unable to catch, but also slows down noisy channel
51+
detection considerably. Defaults to ``True``.
4752
4853
References
4954
----------
@@ -53,7 +58,15 @@ class NoisyChannels:
5358
5459
"""
5560

56-
def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False):
61+
def __init__(
62+
self,
63+
raw,
64+
do_detrend=True,
65+
random_state=None,
66+
matlab_strict=False,
67+
*,
68+
ransac=True,
69+
):
5770
# Make sure that we got an MNE object
5871
assert isinstance(raw, mne.io.BaseRaw)
5972

@@ -68,6 +81,9 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
6881
)
6982
self.matlab_strict = matlab_strict
7083

84+
assert isinstance(ransac, bool), f"ransac must be boolean, got: {ransac}"
85+
self.ransac = ransac
86+
7187
# Extra data for debugging
7288
self._extra_info = {
7389
"bad_by_deviation": {},
@@ -187,18 +203,20 @@ def get_bads(self, verbose=False, as_dict=False):
187203

188204
return bads
189205

190-
def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
206+
def find_all_bads(self, ransac=None, channel_wise=False, max_chunk_size=None):
191207
"""Call all the functions to detect bad channels.
192208
193209
This function calls all the bad-channel detecting functions.
194210
195211
Parameters
196212
----------
197-
ransac : bool, optional
213+
ransac : bool | None
198214
Whether RANSAC should be used for bad channel detection, in addition
199215
to the other methods. RANSAC can detect bad channels that other
200216
methods are unable to catch, but also slows down noisy channel
201-
detection considerably. Defaults to ``True``.
217+
detection considerably. If ``None`` (default), then the value at
218+
instantiation of the ``NoisyChannels`` class is taken (defaults
219+
to ``True``), else the instantiation value is overwritten.
202220
channel_wise : bool, optional
203221
Whether RANSAC should predict signals for chunks of channels over the
204222
entire signal length ("channel-wise RANSAC", see `max_chunk_size`
@@ -218,12 +236,20 @@ def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
218236
effect. Defaults to ``None``.
219237
220238
"""
239+
if ransac is not None and ransac != self.ransac:
240+
assert isinstance(ransac, bool), f"ransac must be boolean, got: {ransac}"
241+
logger.warning(
242+
f"Overwriting `ransac` value. Was `{self.ransac}` at instantiation "
243+
f"of NoisyChannels. Now setting to `{ransac}`."
244+
)
245+
self.ransac = ransac
246+
221247
# NOTE: Bad-by-NaN/flat is already run during init, no need to re-run here
222248
self.find_bad_by_deviation()
223249
self.find_bad_by_hfnoise()
224250
self.find_bad_by_correlation()
225251
self.find_bad_by_SNR()
226-
if ransac:
252+
if self.ransac:
227253
self.find_bad_by_ransac(
228254
channel_wise=channel_wise, max_chunk_size=max_chunk_size
229255
)

0 commit comments

Comments
 (0)