10 changes: 5 additions & 5 deletions legacy/examples/RLHF/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
fused_allreduce_gradients(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+ enable_dp_comm_overlap = False
+ enable_release_grads = False
+ if args.pipeline_parallel_degree > 1:
+     enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

# Case 3: Pipeline parallel mode, overlap with dp
if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
10 changes: 5 additions & 5 deletions llm/alignment/ppo/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
fused_allreduce_gradients(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+ enable_dp_comm_overlap = False
+ enable_release_grads = False
+ if args.pipeline_parallel_degree > 1:
+     enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

# Case 3: Pipeline parallel mode, overlap with dp
if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
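Both trainer_utils.py files receive the same change: instead of building a throwaway set from the raw config string, the flags default to False and are read from args.pipeline_parallel_config only when pipeline parallelism is enabled. Below is a minimal sketch of the new call-site behavior, using a hypothetical SimpleNamespace stand-in for the args object and assuming the config field already holds a parsed set of flags (as the split_parallel_config helper added in training_args.py further down suggests):

from types import SimpleNamespace

# Hypothetical stand-in for the relevant TrainingArguments fields; not PaddleNLP code.
# Assumes pipeline_parallel_config already holds a parsed set of flags.
args = SimpleNamespace(
    pipeline_parallel_degree=2,
    pipeline_parallel_config={"enable_dp_comm_overlap", "enable_release_grads"},
)

# New logic from the diff: default both flags to False, then read them only when
# pipeline parallelism is actually enabled.
enable_dp_comm_overlap = False
enable_release_grads = False
if args.pipeline_parallel_degree > 1:
    enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
    enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

print(enable_dp_comm_overlap, enable_release_grads)  # True True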
25 changes: 10 additions & 15 deletions paddlenlp/trainer/trainer.py
@@ -1083,17 +1083,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- sharding_parallel_config = (
-     set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = (
-     "enable_release_grads" in pipeline_parallel_config
-     or "enable_release_grads" in sharding_parallel_config
- )
+ enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+
+ enable_release_grads = False
+ if args.sharding_parallel_degree > 1:
+     enable_release_grads = "enable_release_grads" in args.sharding_parallel_config
Contributor:
This check doesn't look right. The original was enable_release_grads = (... or ...); written this way it is effectively decided by args.pipeline_parallel_config alone.

Contributor (Author):
Fixed, thanks for pointing it out.

+ if not enable_release_grads and args.pipeline_parallel_degree > 1:
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

# Case 3: Pipeline parallel mode, overlap with dp
if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
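The review thread above asks whether the rewritten checks keep the original OR semantics across the pipeline and sharding configs; with the follow-up fix, they do. Here is a small sketch of that equivalence, using hypothetical pre-parsed flag sets rather than PaddleNLP objects:

# Hypothetical stand-ins to illustrate the point discussed in the review thread; not PaddleNLP code.
def old_release_grads(pp_degree, sharding_degree, pp_config_str, sharding_config_str):
    # Original logic: split the raw strings, guarded by the degree checks.
    pipeline_flags = set(pp_config_str.split(" ")) if pp_degree > 1 else set()
    sharding_flags = set(sharding_config_str.split(" ")) if sharding_degree > 1 else set()
    return "enable_release_grads" in pipeline_flags or "enable_release_grads" in sharding_flags


def new_release_grads(pp_degree, sharding_degree, pp_flags, sharding_flags):
    # Rewritten logic: pp_flags / sharding_flags are assumed to be pre-parsed sets.
    enable_release_grads = False
    if sharding_degree > 1:
        enable_release_grads = "enable_release_grads" in sharding_flags
    if not enable_release_grads and pp_degree > 1:
        enable_release_grads = "enable_release_grads" in pp_flags
    return enable_release_grads


# Either parallel dimension can enable the flag; neither does when both degrees are 1.
assert old_release_grads(2, 1, "enable_release_grads", "") is True
assert new_release_grads(2, 1, {"enable_release_grads"}, set()) is True
assert new_release_grads(1, 2, set(), {"enable_release_grads"}) is True
assert new_release_grads(1, 1, {"enable_release_grads"}, {"enable_release_grads"}) is False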
@@ -1992,8 +1988,7 @@ def get_expected_keys(inputs, keys):
"please upgrade your paddle (using nightly version)."
)

- sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
- if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+ if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
model._set_reduce_overlap(True)
optimizer._set_broadcast_overlap(True, model)

@@ -2133,9 +2128,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
def _enable_delay_scale_loss(self):
key = "enable_delay_scale_loss"
if self.args.pipeline_parallel_degree > 1:
- return key in self.args.pipeline_parallel_config.split(" ")
+ return key in self.args.pipeline_parallel_config
elif self.args.tensor_parallel_degree > 1:
- return key in self.args.tensor_parallel_config.split(" ")
+ return key in self.args.tensor_parallel_config
else:
return False

39 changes: 14 additions & 25 deletions paddlenlp/trainer/training_args.py
@@ -1022,6 +1022,13 @@ def __post_init__(self):
logger.warning("set amp_master_grad to false since amp is disabled.")
self.amp_master_grad = False

+ def split_parallel_config(parallel_config):
+     if "," in parallel_config:
+         parallel_config = set(parallel_config.split(","))
+     else:
+         parallel_config = set(parallel_config.split(" "))
+     return parallel_config

# use_hybrid_parallel
if self.use_hybrid_parallel:

@@ -1039,10 +1046,7 @@ def __post_init__(self):
strategy = fleet.DistributedStrategy()
assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
if self.pipeline_parallel_degree > 1:
- if " " in self.pipeline_parallel_config:
-     pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
- else:
-     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+ pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
for x in pipeline_parallel_config:
if len(x) > 0:
if x not in [
@@ -1116,10 +1120,7 @@ def __post_init__(self):
if self.tensor_parallel_degree > 1:
strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

- if " " in self.tensor_parallel_config:
-     mp_config = set(self.tensor_parallel_config.split(" "))
- else:
-     mp_config = set(self.tensor_parallel_config.split(","))
+ mp_config = split_parallel_config(self.tensor_parallel_config)

for x in mp_config:
if len(x) > 0:
@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
strategy.hybrid_configs = hybrid_configs

if self.sharding_parallel_degree > 1:
- if " " in self.sharding_parallel_config:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
- else:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+ sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)

for x in sharding_parallel_config:
if len(x) > 0:
if x not in [
@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():

# navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
- if " " in self.pipeline_parallel_config:
-     pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
- else:
-     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+ pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
for x in pipeline_parallel_config:
if len(x) > 0:
if x not in [
@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():

if self.tensor_parallel_degree > 1:
mp_optimization = strategy.mp_optimization

- if " " in self.tensor_parallel_config:
-     mp_config = set(self.tensor_parallel_config.split(" "))
- else:
-     mp_config = set(self.tensor_parallel_config.split(","))
+ mp_config = split_parallel_config(self.tensor_parallel_config)

for x in mp_config:
if len(x) > 0:
@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
elif ShardingOption.FULL_SHARD in self.sharding:
sharding.stage = 3

- if " " in self.sharding_parallel_config:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
- else:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+ sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
for x in sharding_parallel_config:
if len(x) > 0:
if x not in [
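For reference, here is a standalone sketch of the new split_parallel_config helper and the kind of membership test the call sites above now perform; the flag strings are hypothetical examples, not values taken from the PR:

# The helper added inside TrainingArguments.__post_init__, reproduced standalone.
# It accepts either comma- or space-separated flag strings and returns a set of flags.
def split_parallel_config(parallel_config):
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config


print(split_parallel_config("enable_dp_comm_overlap enable_release_grads"))
# -> a set containing 'enable_dp_comm_overlap' and 'enable_release_grads'
print(split_parallel_config("enable_dp_comm_overlap,enable_release_grads"))
# -> the same set for the comma-separated form

flags = split_parallel_config("enable_stage2_overlap")
print("enable_stage2_overlap" in flags)  # True

The same helper replaces the repeated space-versus-comma branching that previously appeared at each call site in __post_init__.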