Commit 4b200f9

[Distributed] support fuse optimizer (#9519)
1 parent dd43d5d commit 4b200f9

File tree

1 file changed (+7 −2)

paddlenlp/trainer/training_args.py

Lines changed: 7 additions & 2 deletions
@@ -272,6 +272,7 @@ class TrainingArguments:
             enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
             disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
             enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
+            enable_fuse_optimizer_states, fuse optimizer states to a single storage.
         recompute (`bool`, *optional*, defaults to `False`):
             Recompute the forward pass to calculate gradients. Used for saving memory.
             Only support for networks with transformer blocks.
@@ -1288,10 +1289,11 @@ def is_segment_parallel_supported():
                         "enable_stage1_broadcast_overlap",
                         "enable_stage1_allgather_overlap",
                         "enable_release_grads",
+                        "enable_fuse_optimizer_states",
                     ]:
                         raise ValueError(
-                            f"Found unknown pipeline mode config {x}, "
-                            f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap."
+                            f"Found unknown sharding mode config {x}, "
+                            f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap, enable_stage2_overlap, split_param, disable_stage1_reduce_avg, enable_stage1_broadcast_overlap, enable_stage1_allgather_overlap, enable_release_grads, enable_fuse_optimizer_states."
                         )
                 if "disable_stage1_reduce_avg" in sharding_parallel_config:
                     assert self.sharding == [
@@ -1316,6 +1318,9 @@ def is_segment_parallel_supported():
                 if "enable_release_grads" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].release_gradients = True
 
+                if "enable_fuse_optimizer_states" in sharding_parallel_config:
+                    strategy.hybrid_configs["sharding_configs"].enable_fuse_optimizer_states = True
+
                 if self.pipeline_parallel_degree == 1:
                     strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
                         True if "enable_stage1_tensor_fusion" in sharding_parallel_config else False
