@@ -232,11 +232,11 @@ def __init__(self, config: Union[transformers.PretrainedConfig, str]):
232
232
super ().__init__ ()
233
233
if isinstance (config , str ):
234
234
self .config = transformers .AutoConfig .from_pretrained (config )
235
+ self .base_model = transformers .AutoModelForCausalLM .from_pretrained (config )
235
236
else :
236
237
self .config = config
237
- self .base_model = transformers .AutoModelForCausalLM .from_pretrained (
238
- self .config .name_or_path
239
- )
238
+ self .base_model = transformers .AutoModelForCausalLM .from_config (config )
239
+
240
240
self .base_model .transformer = hf_get_causal_base_model (self .base_model )
241
241
self .base_model .lm_head = hf_get_lm_head (self .base_model )
242
242
self .v_head = make_head (hf_get_hidden_size (self .config ), 1 )
@@ -304,13 +304,14 @@ def __init__(
304
304
num_layers_unfrozen : int = - 1 ,
305
305
):
306
306
super ().__init__ ()
307
+
307
308
if isinstance (config , str ):
308
309
self .config = transformers .AutoConfig .from_pretrained (config )
310
+ self .base_model = transformers .AutoModelForCausalLM .from_pretrained (config )
309
311
else :
310
312
self .config = config
311
- self .base_model = transformers .AutoModelForCausalLM .from_pretrained (
312
- self .config .name_or_path
313
- )
313
+ self .base_model = transformers .AutoModelForCausalLM .from_config (config )
314
+
314
315
self .base_model .transformer = hf_get_causal_base_model (self .base_model )
315
316
self .base_model .lm_head = hf_get_lm_head (self .base_model )
316
317
self .v_head = make_head (hf_get_hidden_size (self .config ), 1 )
0 commit comments