8 changes: 8 additions & 0 deletions docs/trainer.md
@@ -506,6 +506,14 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
with 8 cards, then set sharding_degree=8; sharding will then only communicate inside each machine.
The default of -1 means parameters are sharded across all workers. (`int`, *optional*, defaults to `-1`)

--sharding_comm_buffer_size_MB
Set the size (in MB) of the fused gradient buffer used in sharding communication. This option only takes
effect when the sharding option is enabled. The default value of -1 means the fused gradient size for all
communication follows the default configuration, which is 256MB.
(`int`, *optional*, defaults to `-1`)
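For illustration, a minimal sketch (not part of the original docs) of passing this option programmatically, assuming the standard `TrainingArguments` constructor and a `stage1` sharding setup; the `output_dir` value is a placeholder:

```python
from paddlenlp.trainer import TrainingArguments

# Hypothetical configuration: enable sharding and raise the fused-gradient
# communication buffer from the 256MB default to 512MB.
args = TrainingArguments(
    output_dir="./checkpoints",        # placeholder output path
    sharding="stage1",                 # the buffer size only takes effect when sharding is enabled
    sharding_comm_buffer_size_MB=512,
)
```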

--tensor_parallel_degree
Tensor parallelism is the tensor-partitioning method for the Transformer architecture proposed in the Megatron paper.
This method splits the computation of one transformer layer across different cards.
16 changes: 16 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -549,6 +549,17 @@ class TrainingArguments:
)
},
)
sharding_comm_buffer_size_MB: int = field(
default=-1,
metadata={
"help": (
"Set the size of the fuse gradient in sharding communication. This option only takes effect when "
"the sharding option is turned on.The default value is -1, which means that the gradient size of "
"all communication fuses follows the default configuration, which is 256MB. "
)
},
)

save_sharded_model: bool = field(
default=False,
metadata={
@@ -1293,6 +1304,11 @@ def is_segment_parallel_supported():
)

try:
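# Forward a user-specified positive buffer size (in MB) to the fleet sharding config;
# -1 keeps the framework default of 256MB for the fused-gradient communication buffer.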
if self.sharding_comm_buffer_size_MB > 0:
strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
self.sharding_comm_buffer_size_MB
)

if "split_param" in sharding_parallel_config:
strategy.hybrid_configs["sharding_configs"].split_param = True
