File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -344,8 +344,8 @@ def sanitize(self, weights):
344344
345345 return sanitized_weights
346346
347- def load_weights (self , weights ):
348- self .model .load_weights (weights )
347+ def load_weights (self , weights , strict : bool = True ):
348+ self .model .load_weights (weights , strict = strict )
349349
350350 def generate (
351351 self ,
Original file line number Diff line number Diff line change 11import glob
22import importlib
3+ import inspect
34import logging
45import shutil
56from pathlib import Path
@@ -191,7 +192,13 @@ def get_class_predicate(p, m):
191192 class_predicate = get_class_predicate ,
192193 )
193194
194- model .load_weights (list (weights .items ()), strict = strict )
195+ load_weights_sig = inspect .signature (model .load_weights )
196+
197+ kwargs = {"weights" : list (weights .items ())}
198+ if "strict" in load_weights_sig .parameters :
199+ kwargs ["strict" ] = strict
200+
201+ model .load_weights (** kwargs )
195202
196203 if not lazy :
197204 mx .eval (model .parameters ())
You can’t perform that action at this time.
0 commit comments