
Commit dff7dcc

init model using float32 (#8033)
Parent: b72f352

File tree

1 file changed (+8, -8)


llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 8 additions & 8 deletions
@@ -542,16 +542,16 @@ def main():

     print("Final pre-training config:", config)

-    # Set the dtype for loading model
-    dtype = "float32"
-    if training_args.fp16_opt_level == "O2":
-        if training_args.fp16:
-            dtype = "float16"
-        if training_args.bf16:
-            dtype = "bfloat16"
+    # # Set the dtype for loading model
+    # dtype = "float32"
+    # if training_args.fp16_opt_level == "O2":
+    #     if training_args.fp16:
+    #         dtype = "float16"
+    #     if training_args.bf16:
+    #         dtype = "bfloat16"

     with paddle.LazyGuard():
-        model = model_class.from_config(config, dtype=dtype)
+        model = model_class.from_config(config, dtype="float32")
         criterion = criterion_class(config)

     for param in model.parameters():
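The net effect of this change is that the model is always constructed with float32 parameters, regardless of the fp16/bf16 training flags; any half-precision execution is left to mixed-precision machinery at training time rather than being baked in at initialization. A minimal sketch of that pattern, assuming Paddle's AMP auto_cast API and a toy Linear layer standing in for the real model built by model_class.from_config (not the actual run_pretrain_auto.py code):

# Sketch only: float32 initialization plus compute-time mixed precision.
import paddle

# Toy stand-in for model_class.from_config(config, dtype="float32").
model = paddle.nn.Linear(1024, 1024)            # weights allocated as float32
x = paddle.randn([8, 1024], dtype="float32")

# AMP O2 casts eligible ops to float16 during the forward pass, while the
# stored parameters remain float32, matching what float32 init preserves.
with paddle.amp.auto_cast(level="O2", dtype="float16"):
    y = model(x)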
