Skip to content

Commit 1afb9b2

Browse files
authored
fix (#9779)
1 parent 9e0b830 commit 1afb9b2

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

paddlenlp/utils/serialization.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ def find_class(self, mod_name, name):
160160
return super().find_class(mod_name, name)
161161

162162

163+
class SafeUnpickler(pickle.Unpickler):
    """
    An unpickler that refuses every global except a few inert built-in value types.

    ``find_class`` is consulted whenever the pickle stream references a global
    (a class or function to import). Rejecting everything outside a small
    allowlist means a crafted file cannot make the loader import and call
    arbitrary code such as ``os.system``.
    """

    # Module names under which built-ins may appear in a stream. Streams written
    # with protocol <= 2 (and Python 2 pickles) use the legacy name "__builtin__",
    # and the base class only translates it *inside* its own find_class, i.e.
    # after this override has already been asked.
    _BUILTIN_MODULES = frozenset({"builtins", "__builtin__"})

    # Plain value/container types only; none of them can run code when constructed.
    _ALLOWED_BUILTINS = frozenset(
        {"int", "float", "complex", "str", "tuple", "list", "dict", "set", "frozenset"}
    )

    def find_class(self, module, name):
        """
        Resolve a global referenced by the pickle stream, allowing only basic built-in types.

        :param module: The module name as written in the stream.
        :param name: The attribute (class) name as written in the stream.
        :return: The built-in type if it is on the allowlist.
        :raises pickle.UnpicklingError: If the requested global is not on the allowlist.
        """
        if module in self._BUILTIN_MODULES and name in self._ALLOWED_BUILTINS:
            # Always resolve against the real ``builtins`` module, so the legacy
            # module name never reaches the import machinery.
            return super().find_class("builtins", name)
        raise pickle.UnpicklingError(f"Unsafe object loading is prohibited: {module}.{name}")
179+
180+
163181
def _rebuild_tensor_stage(storage, storage_offset, size, stride, requires_grad, backward_hooks):
164182
# if a tensor has shape [M, N] and stride is [1, N], it's column-wise / fortran-style
165183
# if a tensor has shape [M, N] and stride is [M, 1], it's row-wise / C-style

slm/model_zoo/chinesebert/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
LinearDecayWithWarmup,
3030
PolyDecayWithWarmup,
3131
)
32+
from paddlenlp.utils.serialization import SafeUnpickler
3233

3334
scheduler_type2cls = {
3435
"linear": LinearDecayWithWarmup,
@@ -121,7 +122,7 @@ def save_pickle(data, file_path):
121122

122123
def load_pickle(input_file):
123124
with open(str(input_file), "rb") as f:
124-
data = pickle.load(f)
125+
data = SafeUnpickler(f).load()
125126
return data
126127

127128

slm/model_zoo/t5/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
LinearDecayWithWarmup,
2929
PolyDecayWithWarmup,
3030
)
31+
from paddlenlp.utils.serialization import SafeUnpickler
3132

3233

3334
def accuracy(targets, predictions):
@@ -158,5 +159,5 @@ def save_pickle(data, file_path):
158159

159160
def load_pickle(input_file):
160161
with open(str(input_file), "rb") as f:
161-
data = pickle.load(f)
162+
data = SafeUnpickler(f).load()
162163
return data

0 commit comments

Comments
 (0)