Skip to content

Commit f6426bd

Browse files
author
igor
committed
Add logic for onnx weights
1 parent ac8dfb8 commit f6426bd

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

batchflow/models/torch/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,21 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
18121812
self.set_model_mode(mode)
18131813

18141814
return
1815+
elif isinstance(file, str) and file.endswith(".onnx"):
1816+
try:
1817+
from onnx2torch import convert
1818+
except ImportError as e:
1819+
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e
1820+
1821+
model = convert(file).eval()
1822+
self.model = model
1823+
1824+
self.model_to_device()
1825+
1826+
if make_infrastructure:
1827+
self.make_infrastructure()
1828+
1829+
self.set_model_mode(mode)
18151830

18161831
kwargs['map_location'] = self.device
18171832

0 commit comments

Comments
 (0)