Skip to content

Commit 9b04a77

Browse files
committed
compute fingerprint before calling the actual func
1 parent 3d97a26 commit 9b04a77

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/nlp/fingerprint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from copy import deepcopy
23
from dataclasses import asdict
34
from functools import wraps
45
from typing import TYPE_CHECKING
@@ -137,9 +138,12 @@ def wrapper(*args, **kwargs):
137138
if kwargs_for_fingerprint.get("seed") is None and kwargs_for_fingerprint.get("generator") is None:
138139
kwargs_for_fingerprint["generator"] = np.random.default_rng(None)
139140

140-
# add new_fingerprint arg to not in-place transforms
141+
# compute new_fingerprint and add it to the args of not in-place transforms
141142

142-
if not inplace:
143+
if inplace:
144+
new_fingerprint = update_fingerprint(self._fingerprint, func, kwargs_for_fingerprint)
145+
new_inplace_history_item = (func.__name__, deepcopy(args), deepcopy(kwargs))
146+
else:
143147
for fingerprint_name in fingerprint_names: # transforms like `train_test_split` have several hashes
144148
if fingerprint_name not in kwargs:
145149
kwargs_for_fingerprint["fingerprint_name"] = fingerprint_name
@@ -152,9 +156,9 @@ def wrapper(*args, **kwargs):
152156
# Update fingerprint of in-place transforms + update in-place history of transforms
153157

154158
if inplace: # update after calling func so that the fingerprint doesn't change if the function fails
155-
self._fingerprint = update_fingerprint(self._fingerprint, func, kwargs_for_fingerprint)
159+
self._fingerprint = new_fingerprint
156160
for inplace_hist_per_file in self._inplace_history:
157-
inplace_hist_per_file["transforms"].append((func.__name__, args, kwargs))
161+
inplace_hist_per_file["transforms"].append(new_inplace_history_item)
158162

159163
return out
160164

0 commit comments

Comments
 (0)