Skip to content

Commit b506954

Browse files
author
igor
committed
Add pickle_metadata flag
1 parent 5881bb2 commit b506954

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

batchflow/models/torch/base.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ def convert_outputs(self, outputs):
16691669

16701670
# Store model
16711671
def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
1672-
use_safetensors=False, path_safetensors=None,
1672+
use_safetensors=False, path_safetensors=None, pickle_metadata=False,
16731673
batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
16741674
""" Save underlying PyTorch model along with meta parameters (config, device spec, etc).
16751675
@@ -1730,12 +1730,13 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17301730
path_onnx = path_onnx or (path + '_onnx')
17311731
torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)
17321732

1733-
# Save the rest of parameters
1734-
preserved = self.PRESERVE_ONNX - ignore_attributes
1733+
if pickle_metadata:
1734+
# Save the rest of parameters
1735+
preserved = self.PRESERVE_ONNX - ignore_attributes
17351736

1736-
preserved_dict = {item: getattr(self, item) for item in preserved}
1737-
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
1738-
path, pickle_module=pickle_module, **kwargs)
1737+
preserved_dict = {item: getattr(self, item) for item in preserved}
1738+
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
1739+
path, pickle_module=pickle_module, **kwargs)
17391740

17401741
elif use_openvino:
17411742
import openvino as ov
@@ -1753,20 +1754,23 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
17531754

17541755
ov.save_model(model, output_model=path_openvino)
17551756

1756-
# Save the rest of parameters
1757-
preserved = self.PRESERVE_OPENVINO - ignore_attributes
1758-
preserved_dict = {item: getattr(self, item) for item in preserved}
1759-
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
1760-
path, pickle_module=pickle_module, **kwargs)
1757+
if pickle_metadata:
1758+
# Save the rest of parameters
1759+
preserved = self.PRESERVE_OPENVINO - ignore_attributes
1760+
preserved_dict = {item: getattr(self, item) for item in preserved}
1761+
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
1762+
path, pickle_module=pickle_module, **kwargs)
17611763

17621764
elif use_safetensors:
17631765
from safetensors.torch import save_file
17641766
state_dict = self.model.state_dict()
17651767

17661768
path_safetensors = path_safetensors or (path + "safetensors")
17671769
save_file(state_dict, path_safetensors)
1768-
torch.save({'safetensors': True, 'path_safetensors': path_safetensors, **preserved_dict},
1769-
path, pickle_module=pickle_module, **kwargs)
1770+
1771+
if pickle_metadata:
1772+
torch.save({'safetensors': True, 'path_safetensors': path_safetensors, **preserved_dict},
1773+
path, pickle_module=pickle_module, **kwargs)
17701774

17711775
else:
17721776
preserved = set(self.PRESERVE) - set(ignore_attributes)

0 commit comments

Comments
 (0)