Return the OOF instances for classifier-based drift detectors (ClassifierDrift and SpotTheDiffDrift) #665

Merged: 8 commits merged on Oct 25, 2022. The diff below shows changes from 3 commits.
12 changes: 8 additions & 4 deletions alibi_detect/cd/base.py
@@ -240,7 +240,7 @@ def test_probs(
return p_val, dist

@abstractmethod
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
pass

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
@@ -260,7 +260,8 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
K-S test stat if binarize_preds=False, otherwise relative error reduction.
return_probs
Whether to return the instance level classifier probabilities for the reference and test data
(0=reference data, 1=test data).
(0=reference data, 1=test data). The reference and test instances associated with these
probabilities are also returned.
return_model
Whether to return the updated model trained to discriminate reference and test instances.

@@ -270,10 +271,11 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
'meta' has the model's metadata.
'data' contains the drift prediction and optionally the p-value, performance of the classifier
relative to its expectation under the no-change null, the out-of-fold classifier model
prediction probabilities on the reference and test data, and the trained model.
prediction probabilities on the reference and test data as well as the associated reference
and test instances of the out-of-fold predictions, and the trained model.
"""
# compute drift scores
p_val, dist, probs_ref, probs_test = self.score(x)
p_val, dist, probs_ref, probs_test, x_ref_oof, x_test_oof = self.score(x)
drift_pred = int(p_val < self.p_val)

# update reference dataset
@@ -297,6 +299,8 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True,
if return_probs:
cd['data']['probs_ref'] = probs_ref
cd['data']['probs_test'] = probs_test
cd['data']['x_ref_oof'] = x_ref_oof
cd['data']['x_test_oof'] = x_test_oof
if return_model:
cd['data']['model'] = self.model
return cd
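For context on how the extended `predict` output is meant to be consumed, here is a minimal usage sketch (not part of the diff). The toy data, model architecture and hyperparameters are illustrative assumptions only:

```python
import numpy as np
import torch.nn as nn
from alibi_detect.cd import ClassifierDrift

# Toy data purely for illustration: a 2-d Gaussian reference set and a shifted test set.
np.random.seed(0)
x_ref = np.random.randn(200, 2).astype(np.float32)
x = (np.random.randn(200, 2) + 1.).astype(np.float32)

# Small classifier trained to distinguish reference (label 0) from test (label 1) instances.
model = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 2))

cd = ClassifierDrift(x_ref, model, backend='pytorch', p_val=.05, preds_type='logits',
                     n_folds=5, epochs=2)
preds = cd.predict(x, return_p_val=True, return_distance=True, return_probs=True)

print(preds['data']['is_drift'])           # 1 if drift is detected at p_val=.05
probs_test = preds['data']['probs_test']   # OOF probability of belonging to the test set
x_test_oof = preds['data']['x_test_oof']   # the test instances those probabilities refer to
assert len(probs_test) == len(x_test_oof)
```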
10 changes: 6 additions & 4 deletions alibi_detect/cd/pytorch/classifier.py
@@ -157,7 +157,7 @@ def __init__(
if isinstance(train_kwargs, dict):
self.train_kwargs.update(train_kwargs)

def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Compute the out-of-fold drift metric such as the accuracy from a classifier
trained to distinguish the reference data from the data to be tested.
@@ -171,7 +171,8 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, n
-------
p-value, a notion of distance between the trained classifier's out-of-fold performance \
and that which we'd expect under the null assumption of no drift, \
and the out-of-fold classifier model prediction probabilities on the reference and test data
and the out-of-fold classifier model prediction probabilities on the reference and test data \
as well as the associated reference and test instances of the out-of-fold predictions.
"""
x_ref, x = self.preprocess(x)
x, y, splits = self.get_splits(x_ref, x) # type: ignore
@@ -198,9 +199,10 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, n
preds_oof = np.concatenate(preds_oof_list, axis=0)
probs_oof = softmax(preds_oof, axis=-1) if self.preds_type == 'logits' else preds_oof
idx_oof = np.concatenate(idx_oof_list, axis=0)
y_oof = y[idx_oof]
x_oof, y_oof = x[idx_oof], y[idx_oof]
n_cur = y_oof.sum()
n_ref = len(y_oof) - n_cur
p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
probs_sort = probs_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
x_sort = x_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]
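Since the diff only shows the final bookkeeping lines of `score`, here is a small standalone numpy sketch (toy data, not library code) of why sorting by `idx_oof` re-aligns the out-of-fold probabilities with the instances they were predicted for:

```python
import numpy as np

# Concatenated [x_ref, x]: 5 reference instances followed by 5 test instances.
rng = np.random.default_rng(0)
x = np.arange(10).reshape(-1, 1).astype(float)
y = np.array([0] * 5 + [1] * 5)  # 0 = reference, 1 = test

# Pretend 2-fold CV returned held-out indices and per-instance probabilities in fold order.
# Every index appears once, so all instances are out-of-fold here.
idx_oof = np.array([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
probs_oof = rng.random((10, 2))

x_oof, y_oof = x[idx_oof], y[idx_oof]
n_cur = y_oof.sum()
n_ref = len(y_oof) - n_cur

# Sorting by the original index restores [x_ref, x] order, so the first n_ref rows
# correspond to the reference instances and the remainder to the test instances.
order = np.argsort(idx_oof)
probs_sort, x_sort = probs_oof[order], x_oof[order]
assert np.array_equal(x_sort[:n_ref], x[:n_ref])
assert np.array_equal(x_sort[n_ref:], x[n_ref:])
```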
3 changes: 2 additions & 1 deletion alibi_detect/cd/pytorch/spot_the_diff.py
@@ -219,7 +219,8 @@ def predict(
'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
and optionally the p-value, performance of the classifier relative to its expectation under the
no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
data, and the trained model.
data as well as the associated reference and test instances of the out-of-fold predictions,
and the trained model.
"""
preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
preds['data']['diffs'] = preds['data']['model'].diffs.detach().cpu().numpy() # type: ignore
28 changes: 16 additions & 12 deletions alibi_detect/cd/sklearn/classifier.py
@@ -229,7 +229,7 @@ def predict_proba(self, X):

return model

def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Compute the out-of-fold drift metric such as the accuracy from a classifier
trained to distinguish the reference data from the data to be tested.
@@ -243,14 +243,16 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, n
-------
p-value, a notion of distance between the trained classifier's out-of-fold performance \
and that which we'd expect under the null assumption of no drift, \
and the out-of-fold classifier model prediction probabilities on the reference and test data
and the out-of-fold classifier model prediction probabilities on the reference and test data \
as well as the associated reference and test instances of the out-of-fold predictions.
"""
if self.use_oob and isinstance(self.model, RandomForestClassifier):
return self._score_rf(x)

return self._score(x)

def _score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
def _score(self, x: Union[np.ndarray, list]) \
-> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
x_ref, x = self.preprocess(x)
x, y, splits = self.get_splits(x_ref, x, return_splits=True) # type: ignore

@@ -270,24 +272,26 @@ def _score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray,
idx_oof_list.append(idx_te)
probs_oof = np.concatenate(probs_oof_list, axis=0)
idx_oof = np.concatenate(idx_oof_list, axis=0)
y_oof = y[idx_oof]
x_oof, y_oof = x[idx_oof], y[idx_oof]
n_cur = y_oof.sum()
n_ref = len(y_oof) - n_cur
p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
probs_sort = probs_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
x_sort = x_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]

def _score_rf(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray, np.ndarray]:
def _score_rf(self, x: Union[np.ndarray, list]) \
-> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
x_ref, x = self.preprocess(x)
x, y = self.get_splits(x_ref, x, return_splits=False) # type: ignore
self.model.fit(x, y)
# it is possible that some inputs do not have OOB scores. This probably means
# that too few trees were used to compute any reliable estimates.
index_oob = np.where(np.all(~np.isnan(self.model.oob_decision_function_), axis=1))[0]
probs_oob = self.model.oob_decision_function_[index_oob]
y_oob = y[index_oob]
idx_oob = np.where(np.all(~np.isnan(self.model.oob_decision_function_), axis=1))[0]
probs_oob = self.model.oob_decision_function_[idx_oob]
x_oob, y_oob = x[idx_oob], y[idx_oob]
# comparison due to ordering in get_splits (i.e., x = [x_ref, x])
n_ref = np.sum(index_oob < len(x_ref)).item()
n_cur = np.sum(index_oob >= len(x_ref)).item()
n_ref = np.sum(idx_oob < len(x_ref)).item()
n_cur = np.sum(idx_oob >= len(x_ref)).item()
p_val, dist = self.test_probs(y_oob, probs_oob, n_ref, n_cur)
return p_val, dist, probs_oob[:n_ref, 1], probs_oob[n_ref:, 1]
return p_val, dist, probs_oob[:n_ref, 1], probs_oob[n_ref:, 1], x_oob[:n_ref], x_oob[n_ref:]
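The `_score_rf` path relies on scikit-learn's out-of-bag machinery. The sketch below (toy data, illustrative only; the exact behaviour for instances without OOB votes depends on the scikit-learn version) mirrors the filtering logic used above:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Toy data: 100 "reference" (label 0) and 100 "test" (label 1) instances.
rng = np.random.default_rng(0)
x = np.vstack([rng.normal(0., 1., (100, 2)), rng.normal(1., 1., (100, 2))])
y = np.array([0] * 100 + [1] * 100)

# With very few trees, some instances may never be out-of-bag for any tree.
clf = RandomForestClassifier(n_estimators=5, bootstrap=True, oob_score=True, random_state=0)
clf.fit(x, y)

oob = clf.oob_decision_function_  # shape (n_samples, 2); rows without OOB votes may be NaN
idx_oob = np.where(np.all(~np.isnan(oob), axis=1))[0]
probs_oob, x_oob, y_oob = oob[idx_oob], x[idx_oob], y[idx_oob]
print(f'{len(idx_oob)} of {len(x)} instances have usable OOB estimates')
```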
3 changes: 2 additions & 1 deletion alibi_detect/cd/spot_the_diff.py
@@ -181,6 +181,7 @@ def predict(
'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
and optionally the p-value, performance of the classifier relative to its expectation under the
no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
data, and the trained model.
data as well as the associated reference and test instances of the out-of-fold predictions,
and the trained model.
"""
return self._detector.predict(x, return_p_val, return_distance, return_probs, return_model)
11 changes: 7 additions & 4 deletions alibi_detect/cd/tensorflow/classifier.py
@@ -144,7 +144,8 @@ def __init__(
if isinstance(train_kwargs, dict):
self.train_kwargs.update(train_kwargs)

def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]: # type: ignore[override]
def score(self, x: np.ndarray) \
-> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: # type: ignore[override]
"""
Compute the out-of-fold drift metric such as the accuracy from a classifier
trained to distinguish the reference data from the data to be tested.
@@ -158,7 +159,8 @@ def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:
-------
p-value, a notion of distance between the trained classifier's out-of-fold performance \
and that which we'd expect under the null assumption of no drift, \
and the out-of-fold classifier model prediction probabilities on the reference and test data
and the out-of-fold classifier model prediction probabilities on the reference and test data \
as well as the associated reference and test instances of the out-of-fold predictions.
"""
x_ref, x = self.preprocess(x) # type: ignore[assignment]
x, y, splits = self.get_splits(x_ref, x) # type: ignore
@@ -185,9 +187,10 @@ def score(self, x: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:
preds_oof = np.concatenate(preds_oof_list, axis=0)
probs_oof = softmax(preds_oof, axis=-1) if self.preds_type == 'logits' else preds_oof
idx_oof = np.concatenate(idx_oof_list, axis=0)
y_oof = y[idx_oof]
x_oof, y_oof = x[idx_oof], y[idx_oof]
n_cur = y_oof.sum()
n_ref = len(y_oof) - n_cur
p_val, dist = self.test_probs(y_oof, probs_oof, n_ref, n_cur)
probs_sort = probs_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1]
x_sort = x_oof[np.argsort(idx_oof)]
return p_val, dist, probs_sort[:n_ref, 1], probs_sort[n_ref:, 1], x_sort[:n_ref], x_sort[n_ref:]
3 changes: 2 additions & 1 deletion alibi_detect/cd/tensorflow/spot_the_diff.py
@@ -215,7 +215,8 @@ def predict(
'data' contains the drift prediction, the diffs used to distinguish reference from test instances,
and optionally the p-value, performance of the classifier relative to its expectation under the
no-change null, the out-of-fold classifier model prediction probabilities on the reference and test
data, and the trained model.
data as well as the associated reference and test instances of the out-of-fold predictions,
and the trained model.
"""
preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
preds['data']['diffs'] = preds['data']['model'].diffs.numpy() # type: ignore
5 changes: 4 additions & 1 deletion doc/source/cd/methods/classifierdrift.ipynb
@@ -139,7 +139,7 @@
"source": [
"### Detect Drift\n",
"\n",
"We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift and `return_probs` equals *True* also returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data).\n",
"We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift and `return_probs` equals *True* also returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) as well as the associated out-of-fold reference and test instances.\n",
"\n",
"The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:\n",
"\n",
@@ -155,6 +155,9 @@
"\n",
"* `probs_test`: the instance level prediction probability for the test data `x` if `return_probs` is *true*.\n",
"\n",
"* `x_ref_oof`: the instances associated with `probs_ref` if `return_probs` equals *True*.\n",
"\n",
"* `x_test_oof`: the instances associated with `probs_test` if `return_probs` equals *True*.\n",
"\n",
"```python\n",
"preds = cd.predict(x)\n",
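As a hypothetical follow-up to the notebook usage above (assuming a fitted `ClassifierDrift` detector `cd` and a test batch `x`, as in the docs), the new keys make it straightforward to inspect which test instances drive a drift call:

```python
import numpy as np

# Assumes `cd` is a fitted ClassifierDrift detector and `x` a test batch.
preds = cd.predict(x, return_p_val=True, return_distance=True, return_probs=True)

probs_test = preds['data']['probs_test']   # OOF P(test) for each returned test instance
x_test_oof = preds['data']['x_test_oof']   # the matching instances, aligned with probs_test

# Test instances the classifier separates most confidently from the reference data,
# i.e. the examples contributing most to a drift detection.
top = np.argsort(probs_test)[::-1][:5]
most_drifted = x_test_oof[top]
```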
38 changes: 21 additions & 17 deletions doc/source/cd/methods/spotthediffdrift.ipynb
@@ -53,7 +53,7 @@
"* `initial_diffs`: Array used to initialise the diffs that will be learned. Defaults to Gaussian for each feature with equal variance to that of reference data.\n",
"\n",
"* `l1_reg`: Strength of l1 regularisation to apply to the differences.\n",
" \n",
"\n",
"* `binarize_preds`: Whether to test for discrepency on soft (e.g. probs/logits) model predictions directly with a K-S test or binarise to 0-1 prediction errors and apply a binomial test.\n",
"\n",
"* `train_size`: Optional fraction (float between 0 and 1) of the dataset used to train the classifier. The drift is detected on *1 - train_size*. Cannot be used in combination with `n_folds`.\n",
@@ -109,12 +109,12 @@
"from alibi_detect.cd import SpotTheDiffDrift\n",
"\n",
"cd = SpotTheDiffDrift(\n",
" x_ref, \n",
" backend='pytorch', \n",
" p_val=.05, \n",
" n_diffs=1, \n",
" l1_reg=1e-3, \n",
" epochs=10, \n",
" x_ref,\n",
" backend='pytorch',\n",
" p_val=.05,\n",
" n_diffs=1,\n",
" l1_reg=1e-3,\n",
" epochs=10,\n",
" batch_size=32\n",
")\n",
"\n",
@@ -143,13 +143,13 @@
"\n",
"# instantiate the detector\n",
"cd = SpotTheDiffDrift(\n",
" x_ref, \n",
" backend='tensorflow', \n",
" p_val=.05, \n",
" kernel=kernel, \n",
" n_diffs=1, \n",
" l1_reg=1e-3, \n",
" epochs=10, \n",
" x_ref,\n",
" backend='tensorflow',\n",
" p_val=.05,\n",
" kernel=kernel,\n",
" n_diffs=1,\n",
" l1_reg=1e-3,\n",
" epochs=10,\n",
" batch_size=32\n",
")\n",
"```"
@@ -161,7 +161,7 @@
"source": [
"### Detect Drift\n",
"\n",
"We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift, `return_probs` equals *True* returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) and `return_kernel` equals *True* will also return the trained kernel.\n",
"We detect data drift by simply calling `predict` on a batch of instances `x`. `return_p_val` equal to *True* will also return the p-value of the test, `return_distance` equal to *True* will return a notion of strength of the drift, `return_probs` equals *True* returns the out-of-fold classifier model prediction probabilities on the reference and test data (0 = reference data, 1 = test data) as well as the associated out-of-fold reference and test instances, and `return_kernel` equals *True* will also return the trained kernel.\n",
"\n",
"The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:\n",
"\n",
@@ -181,6 +181,10 @@
"\n",
"* `probs_test`: the instance level prediction probability for the test data `x` if `return_probs` is *true*.\n",
"\n",
"* `x_ref_oof`: the instances associated with `probs_ref` if `return_probs` equals *True*.\n",
"\n",
"* `x_test_oof`: the instances associated with `probs_test` if `return_probs` equals *True*.\n",
"\n",
"* `kernel`: The trained kernel if `return_kernel` equals *True*.\n",
"\n",
"\n",
@@ -201,7 +205,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -215,7 +219,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
"version": "3.7.6"
}
},
"nbformat": 4,