File tree Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -1178,4 +1178,10 @@ def _logits_helper(embedding, output):
11781178 loss_fn = GPTPretrainingCriterionPipe (),
11791179 topology = topology ,
11801180 seg_method = "layer:TransformerDecoderLayer" ,
1181- recompute_interval = 1 if use_recompute else 0 )
1181+ recompute_interval = 1 if use_recompute else 0 ,
1182+ recompute_ctx = {
1183+ "mp_group" :
1184+ fleet .fleet ._hcg .get_model_parallel_group (),
1185+ "offload" : False ,
1186+ "partition" : False
1187+ })
Original file line number Diff line number Diff line change @@ -1165,5 +1165,9 @@ def _logits_helper(embedding, output):
11651165 topology = topology ,
11661166 seg_method = "layer:TransformerDecoderLayer" ,
11671167 recompute_interval = recompute_interval ,
1168- recompute_partition = False ,
1169- recompute_offload = False )
1168+ recompute_ctx = {
1169+ "mp_group" :
1170+ fleet .fleet ._hcg .get_model_parallel_group (),
1171+ "offload" : False ,
1172+ "partition" : False
1173+ })
You can’t perform that action at this time.
0 commit comments