
Commit 9e4466b

update sharding config (#6457)
* update sharding config
* fix
* fix typo
1 parent d4ac513 commit 9e4466b

File tree

1 file changed (+12 -7 lines)


paddlenlp/trainer/training_args.py

Lines changed: 12 additions & 7 deletions
@@ -236,6 +236,7 @@ class TrainingArguments:
             Some additional config it highly affect the useage of sharding parallel, we provide some option to config it.
             following config is support:
             enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation
+            enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed
         recompute (`bool`, *optional*, defaults to `False`):
             Recompute the forward pass to calculate gradients. Used for saving memory.
             Only support for networks with transformer blocks.
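Both flags are passed as a single space-separated string via `sharding_parallel_config` (the parser splits on spaces, as the `__post_init__` hunk further below shows). A minimal usage sketch, assuming a stage-1 sharding run; the `output_dir`, `sharding`, and degree values here are illustrative and not taken from this commit:

    from paddlenlp.trainer import TrainingArguments

    # Illustrative values; only sharding_parallel_config and
    # gradient_accumulation_steps relate directly to this commit.
    args = TrainingArguments(
        output_dir="./checkpoints",
        sharding="stage1",
        sharding_parallel_degree=8,
        gradient_accumulation_steps=4,
        # space-separated flags, split with .split(" ") in __post_init__
        sharding_parallel_config="enable_stage1_tensor_fusion enable_stage1_overlap",
    )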
@@ -541,7 +542,8 @@ class TrainingArguments:
             "help": (
                 "Some additional config it highly affect the useage of sharding parallel, we provide some option to config it."
                 "following config is support: \n"
-                "enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation"
+                "enable_stage1_tensor_fusion, fuse small tensors into big tensor chunks to accelerate communications, may increase memory occupation\n"
+                "enable_stage1_overlap, fuse small tensors into big tensor chunks to accelerate communications and do communication overlap with backward computation, may harm the backward speed"
             )
         },
     )
@@ -852,21 +854,24 @@ def __post_init__(self):
             sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
             for x in sharding_parallel_config:
                 if len(x) > 0:
-                    if x not in [
-                        "enable_stage1_tensor_fusion",
-                    ]:
+                    if x not in ["enable_stage1_tensor_fusion", "enable_stage1_overlap"]:
                         raise ValueError(
                             f"Found unknown pipeline mode config {x}, "
-                            f"accpet config is enable_stage1_tensor_fusion."
+                            f"accpet config is enable_stage1_tensor_fusion, enable_stage1_overlap."
                         )
             try:
                 strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
                     True if "enable_stage1_tensor_fusion" in sharding_parallel_config else False
                 )
+                if "enable_stage1_overlap" in sharding_parallel_config:
+                    strategy.hybrid_configs["sharding_configs"].comm_overlap = True
+                    strategy.hybrid_configs[
+                        "sharding_configs"
+                    ].accumulate_steps = self.gradient_accumulation_steps
             except KeyError:
                 warnings.warn(
-                    "The enable_stage1_tensor_fusion is not supported by current version of Paddle. "
-                    "Please try lateset develop Paddle."
+                    "The enable_stage1_tensor_fusion or enable_stage1_overlap is not supported "
+                    "by current version of Paddle. Please try latest develop Paddle."
                 )
             fleet.init(is_collective=True, strategy=strategy)

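For reference, a standalone sketch of what the parsing above does to the fleet sharding strategy. The attribute names (`tensor_fusion`, `comm_overlap`, `accumulate_steps`) and the try/except pattern come from the diff; building `strategy` directly from `fleet.DistributedStrategy()` here is an assumption for illustration rather than the trainer's actual setup:

    import warnings
    from paddle.distributed import fleet

    # Assumption: stands in for the `strategy` object used in __post_init__.
    strategy = fleet.DistributedStrategy()
    flags = set("enable_stage1_tensor_fusion enable_stage1_overlap".split(" "))

    try:
        sharding_cfg = strategy.hybrid_configs["sharding_configs"]
        # Tensor fusion packs small gradients into larger buffers before communication.
        sharding_cfg.tensor_fusion = "enable_stage1_tensor_fusion" in flags
        if "enable_stage1_overlap" in flags:
            # Overlap sharding communication with backward computation;
            # accumulate_steps must match the trainer's gradient_accumulation_steps.
            sharding_cfg.comm_overlap = True
            sharding_cfg.accumulate_steps = 4  # e.g. self.gradient_accumulation_steps
    except KeyError:
        warnings.warn("sharding_configs not available; try latest develop Paddle.")

    fleet.init(is_collective=True, strategy=strategy)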
