Commit 6aa71b4

[Auto Parallel] fix sharding bug && automatically mark split points (#69426)
1 parent a8aca59 commit 6aa71b4

6 files changed, +134 -33 lines changed


python/paddle/distributed/auto_parallel/api.py

Lines changed: 1 addition & 0 deletions
@@ -1055,6 +1055,7 @@ def _set_and_check_sharding_prop_from_param(self):
             'dp' in self._shard_fn._mesh.dim_names
         ):
             self._sharding_degree = self._shard_fn._mesh.get_dim_size('dp')
+            self._sharding_mesh_axis = 0
         else:
             param_list = self._inner_opt._parameter_list
             for param in param_list:
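
The added line records which mesh axis the 'dp' dimension occupies, so that downstream sharding logic knows which axis to shard along. Below is a minimal sketch of the two values this branch resolves, assuming an illustrative 2x2 ProcessMesh with dim_names ["dp", "mp"] (the mesh shape and names are not taken from the commit):

import paddle.distributed as dist

# Illustrative 2x2 mesh: 'dp' is the first dimension, 'mp' the second.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

if "dp" in mesh.dim_names:
    sharding_degree = mesh.get_dim_size("dp")  # 2: number of ranks along 'dp'
    sharding_mesh_axis = 0                     # mirrors the axis hard-coded by the fix

Here the sharding degree (2) and the mesh axis (0) are exactly the two attributes the optimizer wrapper stores; before this commit only the degree was set in this branch.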

python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py

Lines changed: 43 additions & 2 deletions
@@ -13,14 +13,18 @@
 # limitations under the License.
 
 import itertools
+import re
 from collections import OrderedDict
 from enum import Enum
 
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from paddle.distributed.utils.log_utils import get_logger
 
 from .parallel_base import ParallelModel, ParallelOptimizer, is_tensor
 
+logger = get_logger("INFO", __name__)
+
 
 class SplitPoint(Enum):
     BEGINNING = 0
@@ -137,7 +141,11 @@ def pipeline_parallel(model, optimizer, split_spec, mesh=None, dimension=None):
     Args:
         model (paddle.nn.Layer): A single card model to be distributed
         optimizer (paddle.optimizer.Optimizer): An optimizer to be distributed
-        split_spec (OrderedDict): Pipeline parallel split point, the order of the keys is the order of the pipeline stage
+        split_spec (OrderedDict|dict|str): The pipeline parallel split points.
+            If split_spec is a string such as "llama.layer", the sublayers whose names share that prefix are divided equally across stages according to the pipeline degree.
+            If split_spec is an OrderedDict or dict, each key is a layer name and each value is the split position, which can be SplitPoint.BEGINNING or SplitPoint.END.
+            NOTE: dict preserves insertion order since Python 3.7, so a plain dict can be used here as well.
+            The order of the keys is the order of the pipeline stages.
         mesh (ProcessMesh): A ProcessMesh Object.
         dimension (int|str): The mesh dimension to pipeline the model.
 
@@ -158,7 +166,40 @@ def pipeline_parallel(model, optimizer, split_spec, mesh=None, dimension=None):
             "Specifying a custom mesh is not supported currently"
         )
 
-    model = PipelineParallel(model, split_spec)
+    if isinstance(split_spec, str):
+        # Match layer names made of split_spec followed by a dot and digits and nothing else,
+        # e.g. with split_spec = "llama.layer", "llama.layer.0" matches but "llama.layer.0.mlp" does not.
+        pattern = rf"{split_spec}\.\d+$"
+        matched_layer_name = [
+            name
+            for name, _ in model.named_sublayers()
+            if re.match(pattern, name)
+        ]
+
+        pp_size = mesh.get_dim_size("pp")
+        layer_num = len(matched_layer_name)
+        assert (
+            layer_num > 0
+        ), "No layer matches the split_spec, please check its correctness"
+        assert (
+            layer_num % pp_size == 0
+        ), f"The number of layers must be divisible by the pp size, but got {layer_num} and {pp_size}"
+        layers_per_rank = layer_num // pp_size
+        split_spec_dict = OrderedDict(
+            [
+                (
+                    f"{split_spec}.{i * layers_per_rank - 1}",
+                    SplitPoint.END,
+                )
+                for i in range(1, pp_size)
+            ]
+        )
+    else:
+        split_spec_dict = split_spec
+
+    logger.info(f"split_spec_dict: {split_spec_dict}")
+
+    model = PipelineParallel(model, split_spec_dict)
     if optimizer is not None:
         optimizer = ParallelOptimizer(optimizer)
 
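
To make the new string form of split_spec concrete, here is a small, self-contained sketch of the derivation above, run on hand-written layer names instead of model.named_sublayers() and with an assumed pipeline degree of 2 (both are illustrative; the real code reads the degree from mesh.get_dim_size("pp")):

import re
from collections import OrderedDict

from paddle.distributed.auto_parallel.intermediate.pipeline_parallel import SplitPoint

split_spec = "llama.layers"
pp_size = 2  # assumed pipeline degree

# Stand-ins for the names returned by model.named_sublayers(); only
# "<prefix>.<index>" names match, so "llama.layers.0.mlp" is excluded.
layer_names = [
    "llama.layers.0",
    "llama.layers.0.mlp",
    "llama.layers.1",
    "llama.layers.2",
    "llama.layers.3",
]

pattern = rf"{split_spec}\.\d+$"
matched = [name for name in layer_names if re.match(pattern, name)]  # 4 decoder layers

layers_per_rank = len(matched) // pp_size  # 2 layers per pipeline stage
split_spec_dict = OrderedDict(
    (f"{split_spec}.{i * layers_per_rank - 1}", SplitPoint.END)
    for i in range(1, pp_size)
)
print(split_spec_dict)  # one split point: stage 0 ends after "llama.layers.1"

With four matched decoder layers and a pipeline degree of 2, the model is cut after llama.layers.1, so layers 0-1 form stage 0 and layers 2-3 form stage 1.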

test/auto_parallel/hybrid_strategy/parallel_api.py

Lines changed: 10 additions & 15 deletions
@@ -14,7 +14,6 @@
 import logging
 import os
 import random
-from collections import OrderedDict
 from functools import reduce
 
 import numpy as np
@@ -26,9 +25,6 @@
 from paddle.distributed.auto_parallel.intermediate.parallelize import (
     parallelize,
 )
-from paddle.distributed.auto_parallel.intermediate.pipeline_parallel import (
-    SplitPoint,
-)
 from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
     ColWiseParallel,
     RowWiseParallel,
@@ -145,6 +141,10 @@ def __init__(self):
         if os.getenv("prepare_input_output") == "true":
             self.sequence_parallel = True
 
+        num_hidden_layers = os.getenv("num_hidden_layers")
+        if num_hidden_layers:
+            self.config.num_hidden_layers = int(num_hidden_layers)
+
         seed = int(os.getenv("seed", 2024))
         np.random.seed(seed)
         random.seed(seed)
@@ -204,17 +204,12 @@ def parallel_model(self, layer, optimizer=None):
         mp_config = None
         pp_config = None
         if self.pp > 1:
-            decoders_per_rank = self.config.num_hidden_layers // self.pp
-            split_spec = OrderedDict(
-                [
-                    (
-                        f"llama.layers.{i * decoders_per_rank - 1}",
-                        SplitPoint.END,
-                    )
-                    for i in range(1, self.pp)
-                ]
-            )
-            pp_config = {'split_spec': split_spec}
+            # decoders_per_rank = self.config.num_hidden_layers // self.pp
+            # split_spec = {
+            #     f"llama.layers.{i * decoders_per_rank - 1}": SplitPoint.END
+            #     for i in range(1, self.pp)
+            # }
+            pp_config = {'split_spec': "llama.layers"}
         if self.dp > 1:
             dp_config = {'sharding_level': self.level}
         if self.mp > 1:
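
For the configuration this test used to build by hand, the new string form is equivalent to the explicit spec in the commented-out lines. A hedged illustration with example values num_hidden_layers=4 and pp=2 (matching the num_hidden_layers env added above):

from paddle.distributed.auto_parallel.intermediate.pipeline_parallel import SplitPoint

num_hidden_layers, pp = 4, 2  # example values
decoders_per_rank = num_hidden_layers // pp

# Old style: spell out every split point explicitly.
pp_config = {
    'split_spec': {
        f"llama.layers.{i * decoders_per_rank - 1}": SplitPoint.END
        for i in range(1, pp)
    }  # == {"llama.layers.1": SplitPoint.END}
}

# New style: pass the prefix and let pipeline_parallel derive the same split points.
pp_config = {'split_spec': "llama.layers"}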

test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_1d.py

Lines changed: 6 additions & 15 deletions
@@ -33,15 +33,9 @@ def setUp(self):
             "backend": ["gpu"],
             "amp": ["true"],
             "amp_level": ["O2"],
-            "amp_dtype": [
-                "bfloat16",
-            ],
-            "amp_master_grad": [
-                "False",
-            ],
-            "sharding_stage": [
-                "1",
-            ],
+            "amp_dtype": ["bfloat16"],
+            "amp_master_grad": ["False"],
+            "sharding_stage": ["0", "1"],
         }
 
     def test_simple_net_dp2(self):
@@ -73,12 +67,9 @@ def setUp(self):
             "backend": ["gpu"],
             "amp": ["true"],
             "amp_level": ["O2"],
-            "amp_dtype": [
-                "bfloat16",
-            ],
-            "amp_master_grad": [
-                "False",
-            ],
+            "amp_dtype": ["bfloat16"],
+            "amp_master_grad": ["False"],
+            "num_hidden_layers": ["2", "4"],
         }
 
     def test_simple_net_pp2(self):

test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_2d.py

Lines changed: 73 additions & 0 deletions
@@ -54,5 +54,78 @@ def test_simple_net_mp2_pp2(self):
             ckpt_path.cleanup()
 
 
+class TestDPPPAPI(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=4, timeout=120, nnode=1)
+        self._default_envs = {
+            "dtype": "float32",
+            "seed": "2023",
+            "dp": "2",
+            "mp": "1",
+            "pp": "2",
+            "acc_step": "2",
+        }
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "amp": ["true"],
+            "amp_level": ["O2"],
+            "amp_dtype": ["bfloat16"],
+            "amp_master_grad": ["true"],
+            "use_lazy_init": ["true"],
+            "num_hidden_layers": ["2", "4"],
+            "sharding_stage": ["0"],
+        }
+
+    def test_simple_net_dp2_pp2(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            ckpt_path = tempfile.TemporaryDirectory()
+            envs["ckpt_path"] = ckpt_path.name
+            self.run_test_case(
+                "parallel_api.py",
+                user_defined_envs=envs,
+            )
+            ckpt_path.cleanup()
+
+
+class TestDPMPAPI(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=4, timeout=120, nnode=1)
+        self._default_envs = {
+            "dtype": "float32",
+            "seed": "2023",
+            "dp": "2",
+            "mp": "2",
+            "pp": "1",
+            "acc_step": "2",
+        }
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "amp": ["true"],
+            "amp_level": ["O2"],
+            "amp_dtype": ["bfloat16"],
+            "amp_master_grad": ["true"],
+            "use_lazy_init": ["true"],
+            "sequence_parallel": ["true"],
+            "prepare_input_output": ["false"],
+            "sharding_stage": ["0"],
+        }
+
+    def test_simple_net_mp2_pp2(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            ckpt_path = tempfile.TemporaryDirectory()
+            envs["ckpt_path"] = ckpt_path.name
+            self.run_test_case(
+                "parallel_api.py",
+                user_defined_envs=envs,
+            )
+            ckpt_path.cleanup()
+
+
 if __name__ == "__main__":
     unittest.main()
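
The _changeable_envs matrices above presumably expand into one test run per combination of values via test_base.gen_product_envs_list. A standalone sketch of that expansion using itertools.product (an assumption about its behavior, not the actual test_base implementation):

import itertools

def product_envs(default_envs, changeable_envs):
    # Expand {"num_hidden_layers": ["2", "4"], ...} into one env dict per
    # combination, merged on top of the defaults.
    keys = list(changeable_envs)
    for values in itertools.product(*(changeable_envs[k] for k in keys)):
        envs = dict(default_envs)
        envs.update(zip(keys, values))
        yield envs

# In TestDPPPAPI only num_hidden_layers has two values, so the matrix expands
# into two runs: one with num_hidden_layers=2 and one with num_hidden_layers=4.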

test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_3d.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def setUp(self):
             "use_lazy_init": ["true"],
             "sequence_parallel": ["true"],
             "prepare_input_output": ["false"],
-            "sharding_stage": ["0"],
+            "sharding_stage": ["0", "1"],
         }
 
     def test_simple_net_dp2_mp2_pp2(self):
