@@ -1185,7 +1185,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
11851185 )
11861186
11871187 model_path = os .path .dirname (resolved_archive_file )
1188- state_dict = load_tp_checkpoint (model_path , cls , config )
1188+ state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = True )
11891189 model .set_state_dict (state_dict )
11901190 return model
11911191
@@ -1322,7 +1322,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
13221322 )
13231323
13241324 model_path = os .path .dirname (resolved_archive_file )
1325- state_dict = load_tp_checkpoint (model_path , cls , config )
1325+ state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = True )
13261326 model .set_state_dict (state_dict )
13271327 return model
13281328
@@ -1599,7 +1599,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
15991599 )
16001600
16011601 model_path = os .path .dirname (resolved_archive_file )
1602- state_dict = load_tp_checkpoint (model_path , cls , config )
1602+ state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = True )
16031603 model .set_state_dict (state_dict )
16041604
16051605 return model
0 commit comments