Skip to content

Commit aa6b0d3

Browse files
author
Ashley Scillitoe
committed
Rename rest to reset_state
1 parent feb9fda commit aa6b0d3

15 files changed

+62
-63
lines changed

alibi_detect/cd/base_online.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from pathlib import Path
33
import logging
4+
import warnings
45
from abc import abstractmethod
56
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
67

@@ -219,6 +220,15 @@ def _initialise_state(self) -> None:
219220
self.drift_preds = np.array([]) # type: ignore[var-annotated]
220221

221222
def reset(self) -> None:
223+
"""
224+
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
225+
its initial state (`t=0`) use :meth:`reset_state`.
226+
"""
227+
self.reset_state()
228+
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
229+
'to its initial state use `reset_state`.', DeprecationWarning)
230+
231+
def reset_state(self) -> None:
222232
"""
223233
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
224234
"""
@@ -524,6 +534,15 @@ def _check_drift(self, test_stats: np.ndarray, thresholds: np.ndarray) -> int:
524534
pass
525535

526536
def reset(self) -> None:
537+
"""
538+
Deprecated reset method. This method will be repurposed or removed in the future. To reset the detector to
539+
its initial state (`t=0`) use :meth:`reset_state`.
540+
"""
541+
self.reset_state()
542+
warnings.warn('This method is deprecated and will be removed/repurposed in the future. To reset the detector '
543+
'to its initial state use `reset_state`.', DeprecationWarning)
544+
545+
def reset_state(self) -> None:
527546
"""
528547
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
529548
"""

alibi_detect/cd/lsdd_online.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,11 @@ def test_stats(self):
113113
def thresholds(self):
114114
return [self._detector.thresholds[min(s, self._detector.window_size-1)] for s in range(self.t)]
115115

116-
def reset(self):
117-
"Resets the detector but does not reconfigure thresholds."
118-
self._detector.reset()
116+
def reset_state(self):
117+
"""
118+
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
119+
"""
120+
self._detector.reset_state()
119121

120122
def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True) \
121123
-> Dict[Dict[str, str], Dict[str, Union[int, float]]]:

alibi_detect/cd/mmd_online.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,11 @@ def test_stats(self):
113113
def thresholds(self):
114114
return [self._detector.thresholds[min(s, self._detector.window_size-1)] for s in range(self.t)]
115115

116-
def reset(self):
117-
"Resets the detector but does not reconfigure thresholds."
118-
self._detector.reset()
116+
def reset_state(self):
117+
"""
118+
Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
119+
"""
120+
self._detector.reset_state()
119121

120122
def predict(self, x_t: Union[np.ndarray, Any], return_test_stat: bool = True) \
121123
-> Dict[Dict[str, str], Dict[str, Union[int, float]]]:

alibi_detect/cd/pytorch/tests/test_lsdd_online_pt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def test_lsdd_online(lsdd_online_params, seed):
8787
test_stats_h0.append(pred_t['data']['test_stat'])
8888
if pred_t['data']['is_drift']:
8989
detection_times_h0.append(pred_t['data']['time'])
90-
cd.reset()
90+
cd.reset_state()
9191
average_delay_h0 = np.array(detection_times_h0).mean()
9292
test_stats_h0 = [ts for ts in test_stats_h0 if ts is not None]
9393
assert ert/3 < average_delay_h0 < 3*ert
9494

95-
cd.reset()
95+
cd.reset_state()
9696

9797
detection_times_h1 = []
9898
test_stats_h1 = []
@@ -103,7 +103,7 @@ def test_lsdd_online(lsdd_online_params, seed):
103103
test_stats_h1.append(pred_t['data']['test_stat'])
104104
if pred_t['data']['is_drift']:
105105
detection_times_h1.append(pred_t['data']['time'])
106-
cd.reset()
106+
cd.reset_state()
107107
average_delay_h1 = np.array(detection_times_h1).mean()
108108
test_stats_h1 = [ts for ts in test_stats_h1 if ts is not None]
109109
assert np.abs(average_delay_h1) < ert/2
@@ -139,7 +139,7 @@ def test_lsdd_online_state_online(tmp_path, seed):
139139
test_stats_1.append(preds['data']['test_stat'])
140140

141141
# Reset and check state cleared
142-
dd.reset()
142+
dd.reset_state()
143143
for key, orig_val in state_dict_t0.items():
144144
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
145145

alibi_detect/cd/pytorch/tests/test_mmd_online_pt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def test_mmd_online(mmd_online_params, seed):
8787
test_stats_h0.append(pred_t['data']['test_stat'])
8888
if pred_t['data']['is_drift']:
8989
detection_times_h0.append(pred_t['data']['time'])
90-
cd.reset()
90+
cd.reset_state()
9191
average_delay_h0 = np.array(detection_times_h0).mean()
9292
test_stats_h0 = [ts for ts in test_stats_h0 if ts is not None]
9393
assert ert/3 < average_delay_h0 < 3*ert
9494

95-
cd.reset()
95+
cd.reset_state()
9696

9797
detection_times_h1 = []
9898
test_stats_h1 = []
@@ -103,7 +103,7 @@ def test_mmd_online(mmd_online_params, seed):
103103
test_stats_h1.append(pred_t['data']['test_stat'])
104104
if pred_t['data']['is_drift']:
105105
detection_times_h1.append(pred_t['data']['time'])
106-
cd.reset()
106+
cd.reset_state()
107107
average_delay_h1 = np.array(detection_times_h1).mean()
108108
test_stats_h1 = [ts for ts in test_stats_h1 if ts is not None]
109109
assert np.abs(average_delay_h1) < ert/2
@@ -139,7 +139,7 @@ def test_mmd_online_state_online(tmp_path, seed):
139139
test_stats_1.append(preds['data']['test_stat'])
140140

141141
# Reset and check state cleared
142-
dd.reset()
142+
dd.reset_state()
143143
for key, orig_val in state_dict_t0.items():
144144
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
145145

alibi_detect/cd/tensorflow/tests/test_lsdd_online_tf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,12 @@ def test_lsdd_online(lsdd_online_params, seed):
9898
test_stats_h0.append(pred_t['data']['test_stat'])
9999
if pred_t['data']['is_drift']:
100100
detection_times_h0.append(pred_t['data']['time'])
101-
cd.reset()
101+
cd.reset_state()
102102
average_delay_h0 = np.array(detection_times_h0).mean()
103103
test_stats_h0 = [ts for ts in test_stats_h0 if ts is not None]
104104
assert ert/3 < average_delay_h0 < 3*ert
105105

106-
cd.reset()
106+
cd.reset_state()
107107

108108
detection_times_h1 = []
109109
test_stats_h1 = []
@@ -114,7 +114,7 @@ def test_lsdd_online(lsdd_online_params, seed):
114114
test_stats_h1.append(pred_t['data']['test_stat'])
115115
if pred_t['data']['is_drift']:
116116
detection_times_h1.append(pred_t['data']['time'])
117-
cd.reset()
117+
cd.reset_state()
118118
average_delay_h1 = np.array(detection_times_h1).mean()
119119
test_stats_h1 = [ts for ts in test_stats_h1 if ts is not None]
120120
assert np.abs(average_delay_h1) < ert/2
@@ -150,7 +150,7 @@ def test_lsdd_online_state_online(tmp_path, seed):
150150
test_stats_1.append(preds['data']['test_stat'])
151151

152152
# Reset and check state cleared
153-
dd.reset()
153+
dd.reset_state()
154154
for key, orig_val in state_dict_t0.items():
155155
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
156156

alibi_detect/cd/tensorflow/tests/test_mmd_online_tf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,12 @@ def test_mmd_online(mmd_online_params, seed):
9898
test_stats_h0.append(pred_t['data']['test_stat'])
9999
if pred_t['data']['is_drift']:
100100
detection_times_h0.append(pred_t['data']['time'])
101-
cd.reset()
101+
cd.reset_state()
102102
average_delay_h0 = np.array(detection_times_h0).mean()
103103
test_stats_h0 = [ts for ts in test_stats_h0 if ts is not None]
104104
assert ert/3 < average_delay_h0 < 3*ert
105105

106-
cd.reset()
106+
cd.reset_state()
107107

108108
detection_times_h1 = []
109109
test_stats_h1 = []
@@ -114,7 +114,7 @@ def test_mmd_online(mmd_online_params, seed):
114114
test_stats_h1.append(pred_t['data']['test_stat'])
115115
if pred_t['data']['is_drift']:
116116
detection_times_h1.append(pred_t['data']['time'])
117-
cd.reset()
117+
cd.reset_state()
118118
average_delay_h1 = np.array(detection_times_h1).mean()
119119
print(detection_times_h0, average_delay_h0)
120120
test_stats_h1 = [ts for ts in test_stats_h1 if ts is not None]
@@ -151,7 +151,7 @@ def test_mmd_online_state_online(tmp_path, seed):
151151
test_stats_1.append(preds['data']['test_stat'])
152152

153153
# Reset and check state cleared
154-
dd.reset()
154+
dd.reset_state()
155155
for key, orig_val in state_dict_t0.items():
156156
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
157157

alibi_detect/cd/tests/test_cvm_online.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,21 @@ def test_cvmdriftonline(window_sizes, batch_size, n_feat, seed):
3939
test_stats_h0.append(pred_t['data']['test_stat'])
4040
if pred_t['data']['is_drift']:
4141
detection_times_h0.append(pred_t['data']['time'])
42-
cd.reset()
42+
cd.reset_state()
4343
art = np.array(detection_times_h0).mean() - np.min(window_sizes) + 1
4444
test_stats_h0 = [ts for ts in test_stats_h0 if ts is not None]
4545
assert ert/3 < art < 3*ert
4646

4747
# Drifted data
48-
cd.reset()
48+
cd.reset_state()
4949
detection_times_h1 = []
5050
test_stats_h1 = []
5151
for x_t in x_h1:
5252
pred_t = cd.predict(x_t, return_test_stat=True)
5353
test_stats_h1.append(pred_t['data']['test_stat'])
5454
if pred_t['data']['is_drift']:
5555
detection_times_h1.append(pred_t['data']['time'])
56-
cd.reset()
56+
cd.reset_state()
5757
add = np.array(detection_times_h1).mean() - np.min(window_sizes)
5858
test_stats_h1 = [ts for ts in test_stats_h1 if ts is not None]
5959
assert add < ert/2
@@ -91,7 +91,7 @@ def test_cvm_online_state_online(n_feat, tmp_path, seed):
9191
test_stats_1.append(preds['data']['test_stat'])
9292

9393
# Reset and check state cleared
94-
dd.reset()
94+
dd.reset_state()
9595
for key, orig_val in state_dict_t0.items():
9696
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
9797

alibi_detect/cd/tests/test_fet_online.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_fetdriftonline(alternative, n_feat, seed):
4242
assert cd.t - t0 == 1 # This checks state updated (self.t at least)
4343
if pred_t['data']['is_drift']:
4444
detection_times_h0.append(pred_t['data']['time'])
45-
cd.reset()
45+
cd.reset_state()
4646

4747
# Drifted data
4848
if alternative == 'less':
@@ -52,15 +52,15 @@ def test_fetdriftonline(alternative, n_feat, seed):
5252
p_h1 = 0.9
5353
x_h1 = partial(np.random.choice, (0, 1), size=n_feat, p=[1-p_h1, p_h1])
5454

55-
cd.reset()
55+
cd.reset_state()
5656
count = 0
5757
while len(detection_times_h1) < n_reps and count < int(1e6):
5858
count += 1
5959
x_t = x_h1().reshape(1, 1) if n_feat == 1 else x_h1() # test shape (1,1) in 1D case here
6060
pred_t = cd.predict(x_t)
6161
if pred_t['data']['is_drift']:
6262
detection_times_h1.append(pred_t['data']['time'])
63-
cd.reset()
63+
cd.reset_state()
6464

6565
art = np.array(detection_times_h0).mean() - np.min(window_sizes) + 1
6666
add = np.array(detection_times_h1).mean() - np.min(window_sizes)
@@ -100,7 +100,7 @@ def test_fet_online_state_online(n_feat, tmp_path, seed):
100100
test_stats_1.append(preds['data']['test_stat'])
101101

102102
# Reset and check state cleared
103-
dd.reset()
103+
dd.reset_state()
104104
for key, orig_val in state_dict_t0.items():
105105
np.testing.assert_array_equal(orig_val, getattr(dd, key)) # use np.testing here as it handles torch.Tensor etc
106106

doc/source/cd/methods/onlinecvmdrift.ipynb

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,6 @@
115115
"\n",
116116
"```python\n",
117117
"preds = cd.predict(x_t, return_test_stat=True)\n",
118-
"```\n",
119-
"\n",
120-
"Resetting the detector with the same reference data and thresholds but with a new and empty test-window is straight-forward:\n",
121-
"\n",
122-
"```python\n",
123-
"cd.reset()\n",
124118
"```"
125119
]
126120
},
@@ -146,7 +140,7 @@
146140
"cd.load_state('checkpoint_t1')\n",
147141
"```\n",
148142
"\n",
149-
"At any point, the state may be reset with the `reset` method."
143+
"At any point, the state may be reset with the `reset_state` method."
150144
]
151145
}
152146
],
@@ -169,7 +163,7 @@
169163
"name": "python",
170164
"nbconvert_exporter": "python",
171165
"pygments_lexer": "ipython3",
172-
"version": "3.9.5"
166+
"version": "3.8.11"
173167
}
174168
},
175169
"nbformat": 4,

0 commit comments

Comments
 (0)