Skip to content

Commit 0ed8de7

Browse files
[Recompute] Update recompute for hybrid parallel interface. (#3211)
Co-authored-by: Zhong Hui <[email protected]>
1 parent c6abe76 commit 0ed8de7

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

examples/language_model/gpt-3/dygraph/modeling.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1178,4 +1178,10 @@ def _logits_helper(embedding, output):
11781178
loss_fn=GPTPretrainingCriterionPipe(),
11791179
topology=topology,
11801180
seg_method="layer:TransformerDecoderLayer",
1181-
recompute_interval=1 if use_recompute else 0)
1181+
recompute_interval=1 if use_recompute else 0,
1182+
recompute_ctx={
1183+
"mp_group":
1184+
fleet.fleet._hcg.get_model_parallel_group(),
1185+
"offload": False,
1186+
"partition": False
1187+
})

examples/language_model/moe/dygraph/modeling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,5 +1165,9 @@ def _logits_helper(embedding, output):
11651165
topology=topology,
11661166
seg_method="layer:TransformerDecoderLayer",
11671167
recompute_interval=recompute_interval,
1168-
recompute_partition=False,
1169-
recompute_offload=False)
1168+
recompute_ctx={
1169+
"mp_group":
1170+
fleet.fleet._hcg.get_model_parallel_group(),
1171+
"offload": False,
1172+
"partition": False
1173+
})

0 commit comments

Comments
 (0)