
Commit b504a73

Support hybrid_parallel_topo_order for auto parallel Llama (#8011)
* Support hybrid_parallel_topo_order for auto parallel Llama
* Set order in hybrid_configs
* Update get_mesh_with_dim
* Update loss for CI baseline
* Fix CI errors
* Update loss
* Update loss
* Update loss
* Update loss
* Update loss
1 parent 93aa4bc commit b504a73
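The change makes the process-mesh axis order configurable instead of hard-coded: "pp_first" puts the pipeline axis outermost, while "sharding_first" puts the data/sharding axes outermost. The following standalone sketch (NumPy only, no Paddle required; the 8-rank world and 2x2x2 degrees are chosen purely for illustration) shows how the two orders assign the same ranks to different dp groups:

    import numpy as np

    world_size, degree = 8, {"dp": 2, "pp": 2, "mp": 2}

    def dp_group_of(rank, order):
        # Reshape the flat rank range along the requested axis order,
        # then read off the dp axis that contains `rank`.
        arr = np.arange(world_size).reshape([degree[k] for k in order])
        idx = [int(i[0]) for i in np.where(arr == rank)]
        dp_axis = order.index("dp")
        slicer = tuple(slice(None) if i == dp_axis else idx[i] for i in range(len(order)))
        return arr[slicer].tolist()

    print(dp_group_of(0, ["pp", "dp", "mp"]))  # pp_first layout: [0, 2]
    print(dp_group_of(0, ["dp", "pp", "mp"]))  # sharding_first layout: [0, 4]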

File tree

6 files changed, +78 −45 lines


llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 5 additions & 0 deletions
@@ -404,13 +404,18 @@ def init_seed(seed: int = 1234, args=None):
     else:
         assert not args.use_hybrid_parallel and args.enable_auto_parallel
         if dist.get_world_size() > 1:
+            if args.hybrid_parallel_topo_order is None or args.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "mp", "sep"]
+            elif args.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "mp", "sep"]
             topo = Topology(
                 dist.get_rank(),
                 dist.get_world_size(),
                 dp_degree=args.data_parallel_degree,
                 pp_degree=args.pipeline_parallel_degree,
                 mp_degree=args.tensor_parallel_degree,
                 sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
+                order=order,
             )
 
             global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, topo)

llm/llama/auto_parallel/run_pretrain_auto_static.py

Lines changed: 5 additions & 0 deletions
@@ -414,13 +414,18 @@ def init_seed(seed: int = 1234, args=None):
     else:
         assert not args.use_hybrid_parallel and args.enable_auto_parallel
         if dist.get_world_size() > 1:
+            if args.hybrid_parallel_topo_order is None or args.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "mp", "sep"]
+            elif args.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "mp", "sep"]
             topo = Topology(
                 dist.get_rank(),
                 dist.get_world_size(),
                 dp_degree=args.data_parallel_degree,
                 pp_degree=args.pipeline_parallel_degree,
                 mp_degree=args.tensor_parallel_degree,
                 sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
+                order=order,
             )
 
             global_seed, local_seed, random_seed = _get_distributed_seeds(args.seed, topo)

paddlenlp/ops/distributed/utils/topo.py

Lines changed: 42 additions & 29 deletions
@@ -21,42 +21,55 @@
 
 class Topology:
     def __init__(
-        self, device_rank, world_size, dp_degree=None, pp_degree=1, sharding_degree=1, mp_degree=1, sep_degree=1
+        self,
+        device_rank,
+        world_size,
+        dp_degree=None,
+        pp_degree=1,
+        sharding_degree=1,
+        mp_degree=1,
+        sep_degree=1,
+        order=["dp", "pp", "sharding", "mp", "sep"],
     ):
-        arr = np.arange(0, dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree).reshape(
-            [dp_degree, pp_degree, sharding_degree, mp_degree, sep_degree]
-        )
-
-        dp_rank, pp_rank, sharding_rank, mp_rank, sep_rank = np.where(arr == device_rank)
-        dp_rank = dp_rank[0]
-        pp_rank = pp_rank[0]
-        sharding_rank = sharding_rank[0]
-        mp_rank = mp_rank[0]
-        sep_rank = sep_rank[0]
-
-        self.world = GroupInfo(size=world_size, rank=device_rank, world=list(range(0, world_size)))
+        assert set(order) == {"dp", "pp", "sharding", "mp", "sep"}, f"Illegal order : {order}"
+        self.order = order
 
-        sep_world = arr[dp_rank, pp_rank, sharding_rank, mp_rank, :]
-        self.sep_info = GroupInfo(size=len(sep_world), rank=sep_rank, world=sep_world.tolist())
+        degree_map = {
+            "dp": dp_degree,
+            "pp": pp_degree,
+            "sharding": sharding_degree,
+            "mp": mp_degree,
+            "sep": sep_degree,
+        }
+        shape = [degree_map[key] for key in self.order]
 
-        mp_world = arr[dp_rank, pp_rank, sharding_rank, :, sep_rank]
-        self.mp_info = GroupInfo(size=len(mp_world), rank=mp_rank, world=mp_world.tolist())
+        arr = np.arange(0, dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree).reshape(shape)
+        ranks = [rank[0] for rank in np.where(arr == device_rank)]
 
-        sharding_world = arr[dp_rank, pp_rank, :, mp_rank, sep_rank]
-        self.sharding_info = GroupInfo(size=len(sharding_world), rank=sharding_rank, world=sharding_world.tolist())
-
-        pp_world = arr[dp_rank, :, sharding_rank, mp_rank, sep_rank]
-        self.pp_info = GroupInfo(size=len(pp_world), rank=pp_rank, world=pp_world.tolist())
-
-        dp_world = arr[:, pp_rank, sharding_rank, mp_rank, sep_rank]
-        self.dp_info = GroupInfo(size=len(dp_world), rank=dp_rank, world=dp_world.tolist())
+        self.world = GroupInfo(size=world_size, rank=device_rank, world=list(range(0, world_size)))
+        worlds = []
+        for i in range(len(ranks)):
+            indexs = tuple(ranks[:i] + [slice(None)] + ranks[(i + 1) :])
+            worlds.append(arr[indexs])
+
+        for i, key in enumerate(self.order):
+            if key == "dp":
+                self.dp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "pp":
+                self.pp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "sharding":
+                self.sharding_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "mp":
+                self.mp_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
+            elif key == "sep":
+                self.sep_info = GroupInfo(size=len(worlds[i]), rank=ranks[i], world=worlds[i].tolist())
 
         self.is_last = self.pp_info.rank == self.pp_info.size - 1
 
         data_arr = np.arange(0, dp_degree * sharding_degree).reshape([dp_degree, sharding_degree])
-        data_arr = np.expand_dims(data_arr, axis=1).repeat(pp_degree, axis=1)
-        data_arr = np.expand_dims(data_arr, axis=3).repeat(mp_degree, axis=3)
-        data_arr = np.expand_dims(data_arr, axis=4).repeat(sep_degree, axis=4)
+        for i, key in enumerate(self.order):
+            if key != "dp" and key != "sharding":
+                data_arr = np.expand_dims(data_arr, axis=i).repeat(degree_map[key], axis=i)
 
         self.data_info = GroupInfo(
             size=int(self.dp_info.size * self.sharding_info.size),
@@ -68,4 +81,4 @@ def __init__(
         self.data_inner_times = self.world.size // self.data_info.size
 
     def __repr__(self):
-        return f"dp_info:\n\t {self.dp_info}, \npp_info:\n\t {self.pp_info}, \nsharding_info:\n\t {self.sharding_info}, \nmp_info:\n\t {self.mp_info}, \nsep_info:\n\t {self.sep_info}\ndata_info:\n\t {self.data_info}"
+        return f"dp_info:\n\t {self.dp_info}, \npp_info:\n\t {self.pp_info}, \nsharding_info:\n\t {self.sharding_info}, \nmp_info:\n\t {self.mp_info}, \nsep_info:\n\t {self.sep_info}, \ndata_info:\n\t {self.data_info}, \norder:\n\t {self.order}"

paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 6 deletions
@@ -931,6 +931,10 @@ def __post_init__(self):
             self.pipeline_parallel_degree = -1
             self.sep_parallel_degree = -1
 
+        if self.hybrid_parallel_topo_order is None:
+            self.hybrid_parallel_topo_order = "pp_first"
+        assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
+
         if self.use_hybrid_parallel and self.enable_auto_parallel:
             self.use_hybrid_parallel = False
 
@@ -1058,10 +1062,6 @@ def __post_init__(self):
                     "by current version of Paddle. Please try latest develop Paddle."
                 )
 
-            if self.hybrid_parallel_topo_order is None:
-                self.hybrid_parallel_topo_order = "pp_first"
-            assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
-
             def is_segment_parallel_supported():
                 import inspect
 
@@ -1317,17 +1317,27 @@ def is_segment_parallel_supported():
                     recompute.refined_ops_patterns.append(eval(pattern))
 
             self.strategy = strategy
-            order = ["dp", "pp", "mp"]
-            degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+            if self.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "mp"]
+                degree = [self.pipeline_parallel_degree, self.data_parallel_degree, self.tensor_parallel_degree]
+            elif self.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "pp", "mp"]
+                degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
             mesh_dims = list(zip(order, degree))
             fleet.auto.create_mesh(mesh_dims)
 
             # init hcg for communication in trainer
+            if self.hybrid_parallel_topo_order == "pp_first":
+                order = ["pp", "dp", "sharding", "sep", "mp"]
+            elif self.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "sep", "mp"]
+
             strategy = fleet.DistributedStrategy()
             strategy.hybrid_configs = {
                 "dp_degree": self.data_parallel_degree,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
+                "order": order,
             }
             fleet.init(is_collective=True, strategy=strategy)
 
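For reference, the hybrid communication group setup built here can be reproduced outside the trainer. A minimal sketch, assuming an 8-card job launched with python -m paddle.distributed.launch and the sharding_first order; the degree values are illustrative, not the defaults of any particular config:

    import paddle.distributed.fleet as fleet

    # Mirror the sharding_first branch of __post_init__: dp/sharding axes
    # come before pp when the hybrid communication groups are built.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 2,
        "pp_degree": 2,
        "order": ["dp", "sharding", "pp", "sep", "mp"],
    }
    fleet.init(is_collective=True, strategy=strategy)

    # Inspect which group this rank landed in under the chosen order.
    hcg = fleet.get_hybrid_communicate_group()
    print(hcg.get_data_parallel_rank(), hcg.get_model_parallel_rank())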
paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def swiglu(x, y=None):
 def get_mesh(pp_idx=0):
     mesh = fleet.auto.get_mesh()
     if "pp" in mesh.dim_names:
-        mesh = mesh.get_mesh_with_dim("pp")[pp_idx]
+        mesh = mesh.get_mesh_with_dim("pp", pp_idx)
     return mesh
 
 
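The get_mesh_with_dim change pairs with the pp-first mesh created in training_args.py: the pipeline-stage index is now passed directly instead of indexing the returned mesh. A minimal sketch, assuming fleet.auto.create_mesh has been called as in the trainer; the degrees are illustrative:

    import paddle.distributed.fleet as fleet

    # Global mesh with the pipeline axis outermost, matching the pp_first branch.
    fleet.auto.create_mesh([("pp", 2), ("dp", 2), ("mp", 2)])

    mesh = fleet.auto.get_mesh()
    if "pp" in mesh.dim_names:
        # New call style: pass the stage index to get_mesh_with_dim rather
        # than indexing the result of get_mesh_with_dim("pp").
        stage0_mesh = mesh.get_mesh_with_dim("pp", 0)
        print(stage0_mesh)  # dp x mp sub-mesh for the first pipeline stage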
scripts/distribute/ci_case_auto.sh

Lines changed: 9 additions & 9 deletions
@@ -965,7 +965,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.42011845
+    loss_base=9.42011833
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1031,7 +1031,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.44299495
+    loss_base=9.44299471
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1093,7 +1093,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.2 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1161,7 +1161,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1230,7 +1230,7 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         >>${log_path}/$FUNCNAME 2>&1
-    loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    loss=`cat $case_log_dir/workerlog.4 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
@@ -1301,7 +1301,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.52781677
+    loss_base=9.53389835
    ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1368,7 +1368,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.40659046
+    loss_base=9.39066124
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1435,7 +1435,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.38319206
+    loss_base=9.38235474
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1503,7 +1503,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.38341904
+    loss_base=9.38257694
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
