File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -43,10 +43,16 @@ class AccelerateRLTrainer(BaseRLTrainer):
43
43
RL model trainer with an `accelerate` based backend
44
44
"""
45
45
46
- def __init__ (self , config , ** kwargs ):
46
+ def __init__ (self , config , ** kwargs ): # noqa: C901
47
47
super ().__init__ (config , ** kwargs )
48
48
self .max_length = config .train .seq_length
49
49
self .accelerator = Accelerator (log_with = config .train .tracker , logging_dir = config .train .logging_dir )
50
+
51
+ if self .accelerator .state .deepspeed_plugin is not None :
52
+ # by accelerate's default, arguments in `model.forward` would be casted to half
53
+ if "fp16" in self .accelerator .state .deepspeed_plugin .deepspeed_config :
54
+ self .accelerator .state .deepspeed_plugin .deepspeed_config ["fp16" ]["auto_cast" ] = False
55
+
50
56
if int (os .environ .get ("WORLD_SIZE" , 1 )) > 1 :
51
57
torch .distributed .barrier (device_ids = [int (os .environ .get ("LOCAL_RANK" , 0 ))])
52
58
You can’t perform that action at this time.
0 commit comments