Skip to content

Commit 31b8a46

Browse files
fix(nn/models): initializing base transformer from a custom config (#149)
1 parent af242e1 commit 31b8a46

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

trlx/trainer/nn/ilql_models.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -197,14 +197,14 @@ def __init__(
197197
_hfconfig = transformers.deepspeed.HfDeepSpeedConfig( # noqa: F841
198198
config_path
199199
)
200+
200201
if isinstance(config, str):
201202
self.config = transformers.AutoConfig.from_pretrained(config)
203+
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config)
202204
else:
203205
self.config = config
206+
self.base_model = transformers.AutoModelForCausalLM.from_config(config)
204207

205-
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(
206-
self.config.name_or_path,
207-
)
208208
self.base_model.transformer = hf_get_causal_base_model(self.base_model)
209209
self.base_model.lm_head = hf_get_lm_head(self.base_model)
210210
freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen)

trlx/trainer/nn/ppo_models.py

Lines changed: 7 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -232,11 +232,11 @@ def __init__(self, config: Union[transformers.PretrainedConfig, str]):
232232
super().__init__()
233233
if isinstance(config, str):
234234
self.config = transformers.AutoConfig.from_pretrained(config)
235+
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config)
235236
else:
236237
self.config = config
237-
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(
238-
self.config.name_or_path
239-
)
238+
self.base_model = transformers.AutoModelForCausalLM.from_config(config)
239+
240240
self.base_model.transformer = hf_get_causal_base_model(self.base_model)
241241
self.base_model.lm_head = hf_get_lm_head(self.base_model)
242242
self.v_head = make_head(hf_get_hidden_size(self.config), 1)
@@ -304,13 +304,14 @@ def __init__(
304304
num_layers_unfrozen: int = -1,
305305
):
306306
super().__init__()
307+
307308
if isinstance(config, str):
308309
self.config = transformers.AutoConfig.from_pretrained(config)
310+
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config)
309311
else:
310312
self.config = config
311-
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(
312-
self.config.name_or_path
313-
)
313+
self.base_model = transformers.AutoModelForCausalLM.from_config(config)
314+
314315
self.base_model.transformer = hf_get_causal_base_model(self.base_model)
315316
self.base_model.lm_head = hf_get_lm_head(self.base_model)
316317
self.v_head = make_head(hf_get_hidden_size(self.config), 1)

0 commit comments

Comments (0)