Skip to content

Commit f5a7d78

Browse files
fix(base_trainer): set deepspeed's auto_cast to false for fp16 (#279)
fixes #238
1 parent 6892fc3 commit f5a7d78

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

trlx/trainer/accelerate_base_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,16 @@ class AccelerateRLTrainer(BaseRLTrainer):
4343
RL model trainer with an `accelerate` based backend
4444
"""
4545

46-
def __init__(self, config, **kwargs):
46+
def __init__(self, config, **kwargs): # noqa: C901
4747
super().__init__(config, **kwargs)
4848
self.max_length = config.train.seq_length
4949
self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
50+
51+
if self.accelerator.state.deepspeed_plugin is not None:
52+
# by accelerate's default, arguments in `model.forward` would be casted to half
53+
if "fp16" in self.accelerator.state.deepspeed_plugin.deepspeed_config:
54+
self.accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["auto_cast"] = False
55+
5056
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
5157
torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
5258

0 commit comments

Comments
 (0)