Commit dc2b7c9

[AutoParallel]:parallel api support tp/pp & add share_embedding test (#70182)
1 parent 4410d26 commit dc2b7c9

8 files changed: +199, -65 lines

python/paddle/distributed/auto_parallel/intermediate/parallel_base.py

Lines changed: 108 additions & 23 deletions
@@ -118,7 +118,8 @@ def __init__(self, model):
         self.tp_parallelizer = None
         self.sharding_parallelizer = None
         self.model = None
-
+        self.share_param_list = {}
+        self.layer_param_placements = {}
         if isinstance(model, ParallelModel):
             self.pp_parallelizer = model.pp_parallelizer
             self.tp_parallelizer = model.tp_parallelizer
@@ -147,8 +148,9 @@ def parallelize_model(self):
 
         if self.tp_parallelizer is not None:
             assert callable(self.tp_parallelizer)
-            self.model = self.tp_parallelizer(self.model)
-
+            self.model, self.layer_param_placements = self.tp_parallelizer(
+                self.model
+            )
         if self.sharding_parallelizer is not None:
             assert callable(self.sharding_parallelizer)
             self.model = self.sharding_parallelizer(self.model)
@@ -157,36 +159,119 @@ def parallelize_model(self):
 
         return self.model
 
+    def _process_share_weight_layer(
+        self, layer, origin_weight, param_name, param_placements
+    ):
+        ipp = (
+            layer.pipeline_stage_index
+            if hasattr(layer, "pipeline_stage_index")
+            else 0
+        )
+
+        def create_pre_hook(origin_weight, param_name):
+            def forward_pre_hook(layer, input):
+                setattr(
+                    layer,
+                    param_name,
+                    None,
+                )
+                delattr(layer, param_name)
+                mesh = self.get_mesh(ipp)
+                share_weight = dist.reshard(
+                    origin_weight,
+                    mesh,
+                    param_placements,
+                )
+                setattr(
+                    layer,
+                    param_name,
+                    share_weight,
+                )
+
+            return forward_pre_hook
+
+        def create_post_hook(origin_weight, param_name):
+            def forward_post_hook(layer, input, output):
+                setattr(
+                    layer,
+                    param_name,
+                    origin_weight,
+                )
+
+            return forward_post_hook
+
+        layer.register_forward_pre_hook(
+            create_pre_hook(origin_weight, param_name)
+        )
+        layer.register_forward_post_hook(
+            create_post_hook(origin_weight, param_name)
+        )
+
     def _shard_all_param(self, model):
         param_name_to_shard_param = {}
+        param_name_to_pp_stage = {}
 
         def shard_layer_param(layer):
             if self.pp_parallelizer is not None:
                 assert hasattr(layer, "pipeline_stage_index")
             for param_name in list(layer._parameters.keys()):
                 param = getattr(layer, param_name)
-                if param is not None and not param.is_dist():
+                if param is not None:
                     param_full_name = param.name
-                    if param_full_name in param_name_to_shard_param:
-                        setattr(
-                            layer,
-                            param_name,
-                            param_name_to_shard_param[param_full_name],
-                        )
+                    ipp = (
+                        layer.pipeline_stage_index
+                        if hasattr(layer, "pipeline_stage_index")
+                        else 0
+                    )
+                    mesh = self.get_mesh(ipp)
+                    param_placements = [
+                        dist.Replicate() for _ in range(len(mesh._shape))
+                    ]
+                    if layer in self.layer_param_placements:
+                        if param_name in self.layer_param_placements[layer]:
+                            param_placements = (
+                                self.layer_param_placements[layer][param_name]
+                                if self.layer_param_placements[layer][
+                                    param_name
+                                ]
+                                is not None
+                                else param_placements
+                            )
+                    if not param.is_dist():
+                        if param_full_name in param_name_to_shard_param:
+                            setattr(
+                                layer,
+                                param_name,
+                                param_name_to_shard_param[param_full_name],
+                            )
+                            if ipp != param_name_to_pp_stage[param_full_name]:
+                                self._process_share_weight_layer(
+                                    layer,
+                                    param_name_to_shard_param[param_full_name],
+                                    param_name,
+                                    param_placements,
+                                )
+                        else:
+                            param = dist.shard_tensor(
+                                param, mesh, param_placements
+                            )
+                            param_name_to_shard_param[param_full_name] = param
+                            param_name_to_pp_stage[param_full_name] = ipp
+                            setattr(layer, param_name, param)
                     else:
-                        ipp = (
-                            layer.pipeline_stage_index
-                            if hasattr(layer, "pipeline_stage_index")
-                            else 0
-                        )
-                        mesh = self.get_mesh(ipp)
-                        param = dist.shard_tensor(
-                            param,
-                            mesh,
-                            [dist.Replicate() for _ in range(len(mesh._shape))],
-                        )
-                        param_name_to_shard_param[param_full_name] = param
-                        setattr(layer, param_name, param)
+                        if (
+                            param_full_name in param_name_to_shard_param
+                            and ipp != param_name_to_pp_stage[param_full_name]
+                        ):
+                            self._process_share_weight_layer(
+                                layer,
+                                param_name_to_shard_param[param_full_name],
+                                param_name,
+                                param_placements,
+                            )
+                        elif param_full_name not in param_name_to_shard_param:
+                            param_name_to_shard_param[param_full_name] = param
+                            param_name_to_pp_stage[param_full_name] = ipp
 
         for name, layer in model.named_sublayers():
            shard_layer_param(layer)
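
The `_process_share_weight_layer` helper added above lets a parameter that lives on one pipeline stage be consumed by a layer assigned to another stage: a forward pre-hook swaps in a copy resharded onto the consuming layer's mesh, and a forward post-hook restores the original parameter. The sketch below restates that pattern in isolation; `attach_share_weight_hooks`, `target_mesh`, and `shared_w` are illustrative names that are not part of this commit, and the code assumes it runs inside a multi-rank paddle.distributed job where both stage meshes exist.

import paddle.distributed as dist

def attach_share_weight_hooks(layer, shared_w, target_mesh, placements, param_name="weight"):
    # Hypothetical helper mirroring _process_share_weight_layer above.
    def pre_hook(layer, inputs):
        # Detach the stage-local parameter and temporarily bind a copy
        # resharded onto this layer's own mesh for the forward pass.
        setattr(layer, param_name, None)
        delattr(layer, param_name)
        setattr(layer, param_name, dist.reshard(shared_w, target_mesh, placements))

    def post_hook(layer, inputs, outputs):
        # Restore the original shared parameter once the forward pass is done.
        setattr(layer, param_name, shared_w)

    layer.register_forward_pre_hook(pre_hook)
    layer.register_forward_post_hook(post_hook)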

python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py

Lines changed: 60 additions & 15 deletions
@@ -82,7 +82,7 @@ def c_concat(x, process_mesh, need_transpose):
 
 class PlanBase:
     def __init__(self):
-        pass
+        self.share_param_list = {}
 
     def apply(self, layer, process_mesh, shard_weight, shard_bias):
         raise NotImplementedError("Don't call the PlanBase directly.")
@@ -143,6 +143,7 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True):
         index = process_mesh.dim_names.index('mp')  # get the axis for the split
         size = len(process_mesh.shape)
         placement = [dist.Replicate() for _ in range(size)]
+        param_placements = {}
         assert isinstance(layer, paddle.nn.Layer)
         if not isinstance(layer, (paddle.nn.Linear, paddle.nn.Embedding)):
             logging.warning(
@@ -157,20 +158,39 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True):
         ):
             placement[index] = dist.Shard(1)
             assert len(layer.weight.shape) == 2
-            layer.weight = dist.shard_tensor(
-                layer.weight,
-                process_mesh,
-                placement,
-            )
+            # NOTE(zhangweilong): for share parameter, the parameter should be handled uniformly in the end
+            if (
+                self.share_param_list is not None
+                and layer.weight.name in self.share_param_list
+                and self.share_param_list[layer.weight.name] > 1
+            ):
+                param_placements.update({"weight": placement})
+            else:
+                layer.weight = dist.shard_tensor(
+                    layer.weight,
+                    process_mesh,
+                    placement,
+                )
         if hasattr(layer, "bias") and layer.bias is not None and shard_bias:
             placement[index] = dist.Shard(0)
             assert len(layer.bias.shape) == 1
-            layer.bias = dist.shard_tensor(layer.bias, process_mesh, placement)
+            # NOTE(zhangweilong): for share parameter, the parameter should be handled uniformly in the end
+            if (
+                self.share_param_list is not None
+                and layer.bias.name in self.share_param_list
+                and self.share_param_list[layer.bias.name] > 1
+            ):
+                param_placements.update({"bias": placement})
+            else:
+                layer.bias = dist.shard_tensor(
+                    layer.bias, process_mesh, placement
+                )
 
         if self.gather_output:
             layer.register_forward_post_hook(
                 self.gather_output_hook(process_mesh)
             )
+        return param_placements
 
 
 class RowWiseParallel(PlanBase):
@@ -185,7 +205,7 @@ class RowWiseParallel(PlanBase):
 
     Args:
         is_input_parallel (bool): Whether the input is a local tensor or a global tensor. If the input is a
-            global tensor, an extra split will be called. The default value is `True`
+            global tensor, an extra split will be called. The default value is `True`,
            which means the input is a local tensor.
 
     Examples:
@@ -225,6 +245,7 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=False):
         size = len(process_mesh.shape)
         placement = [dist.Replicate() for _ in range(size)]
         placement[index] = dist.Shard(0)
+        param_placements = {}
         assert isinstance(layer, paddle.nn.Layer)
         if not isinstance(layer, (paddle.nn.Linear, paddle.nn.Embedding)):
             logging.warning(
@@ -238,13 +259,22 @@
             and shard_weight
         ):
             assert len(layer.weight.shape) == 2
-            layer.weight = dist.shard_tensor(
-                layer.weight,
-                process_mesh,
-                placement,
-            )
+            # NOTE(zhangweilong): for share parameter, the parameter should be handled uniformly in the end
+            if (
+                self.share_param_list is not None
+                and layer.weight.name in self.share_param_list
+                and self.share_param_list[layer.weight.name] > 1
+            ):
+                param_placements.update({"weight": placement})
+            else:
+                layer.weight = dist.shard_tensor(
+                    layer.weight,
+                    process_mesh,
+                    placement,
+                )
         if not self.is_input_parallel:
             layer.register_forward_pre_hook(self.split_input_hook(process_mesh))
+        return param_placements
 
 
 class PrepareLayerInput(PlanBase):
@@ -626,20 +656,35 @@ def match_layer(self, name):
     def tensor_parallelizer_fn(self, model):
         if self.parallelize_plan is None:
             return
+        layer_param_placements = {}
+        share_param_list = {}
+        for name, layer in model.named_sublayers():
+            for param_name in list(layer._parameters.keys()):
+                param = getattr(layer, param_name)
+                if param.name not in share_param_list:
+                    share_param_list[param.name] = 1
+                    continue
+                share_param_list[param.name] += 1
         for name, layer in model.named_sublayers():
             plans = self.match_layer(name)
+            layer_param_placements[layer] = {}
             if len(plans) > 0:
                 pp_idx = getattr(layer, "pipeline_stage_index", 0)
                 for plan in plans:
                     real_plan, shard_weight, shard_bias = plan
                     for p in real_plan:
-                        p.apply(
+                        p.share_param_list = share_param_list
+                        param_placements = p.apply(
                             layer,
                             self.get_mesh(pp_idx),
                             shard_weight,
                             shard_bias,
                         )
-        return model
+                        if param_placements is not None and param_placements:
+                            layer_param_placements[layer].update(
+                                param_placements
+                            )
+        return model, layer_param_placements
 
 
 def tensor_parallel(model, optimizer=None, config=None):
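
`tensor_parallelizer_fn` now detects tied parameters before applying any plan: it walks `named_sublayers()` and counts how many sublayers reference each parameter name, and a count greater than 1 marks the parameter as shared, so the plan records its intended placement instead of sharding it in place. Below is a self-contained sketch of that detection; the `TinyShared` toy model is made up for illustration and is not part of this commit.

import paddle.nn as nn

class TinyShared(nn.Layer):
    # Toy model with one tied weight, standing in for a share_embedding setup.
    def __init__(self, hidden=16):
        super().__init__()
        self.proj_in = nn.Linear(hidden, hidden, bias_attr=False)
        self.proj_out = nn.Linear(hidden, hidden, bias_attr=False)
        self.proj_out.weight = self.proj_in.weight  # tie the two weights

model = TinyShared()
share_param_list = {}
for _, layer in model.named_sublayers():
    for param_name in layer._parameters:
        param = getattr(layer, param_name)
        share_param_list[param.name] = share_param_list.get(param.name, 0) + 1

# Parameters counted more than once are shared and get deferred handling.
shared = [name for name, count in share_param_list.items() if count > 1]
print(shared)  # expected: the single tied weight's name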

test/auto_parallel/hybrid_strategy/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -135,23 +135,23 @@ if((WITH_GPU) AND (LINUX))
     ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_parallel_api_with_llama_1d
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=HYBRID")
+                       PROPERTIES TIMEOUT "400" LABELS "RUN_TYPE=HYBRID")
 endif()
 if((WITH_GPU) AND (LINUX))
   py_test_modules(
     test_parallel_api_with_llama_2d MODULES test_parallel_api_with_llama_2d
     ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_parallel_api_with_llama_2d
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=HYBRID")
+                       PROPERTIES TIMEOUT "400" LABELS "RUN_TYPE=HYBRID")
 endif()
 if((WITH_GPU) AND (LINUX))
   py_test_modules(
     test_parallel_api_with_llama_3d MODULES test_parallel_api_with_llama_3d
     ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_parallel_api_with_llama_3d
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=HYBRID")
+                       PROPERTIES TIMEOUT "400" LABELS "RUN_TYPE=HYBRID")
 endif()
 if((WITH_GPU) AND (LINUX))
   py_test_modules(

test/auto_parallel/hybrid_strategy/parallel_api.py

Lines changed: 13 additions & 9 deletions
@@ -160,7 +160,7 @@ def init_dist_env(self):
         global_mesh = dist.ProcessMesh(mesh_arr, dim_names)
         dist.auto_parallel.set_mesh(global_mesh)
 
-    def check_mp(self, layer):
+    def check_mp(self, layer, share_embedding):
         if self.mp == 1:
             return
         for name, sub_layer in layer.named_sublayers():
@@ -174,12 +174,14 @@ def check_mp(self, layer):
                     dist.Replicate(),
                     dist.Shard(0),
                 ]
+            if 'gate_proj' in name or 'up_proj' in name:
+                assert sub_layer.weight.placements == [
+                    dist.Replicate(),
+                    dist.Shard(1),
+                ]
             if (
-                'gate_proj' in name
-                or 'up_proj' in name
-                or 'embed_tokens' in name
-                or 'lm_head' in name
-            ):
+                'embed_tokens' in name or 'lm_head' in name
+            ) and not share_embedding:
                 assert sub_layer.weight.placements == [
                     dist.Replicate(),
                     dist.Shard(1),
@@ -196,7 +198,7 @@ def check_mp(self, layer):
                     dist.Shard(0),
                 ]
 
-    def parallel_model(self, layer):
+    def parallel_model(self, layer, share_embedding=False):
         dp_config = None
         mp_config = None
         pp_config = None
@@ -306,7 +308,7 @@ def parallel_model(self, layer):
             optimizer,
             config=config,
         )
-        self.check_mp(layer)
+        self.check_mp(layer, share_embedding)
         return layer, optimizer, lr_scheduler
 
     def run_llama(
@@ -322,7 +324,9 @@
             self.config, share_embedding, position_embedding
         )
 
-        model, optimizer, lr_scheduler = self.parallel_model(model)
+        model, optimizer, lr_scheduler = self.parallel_model(
+            model, share_embedding
+        )
 
         criterion = LlamaPretrainingCriterion(self.config)
 
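
In the test, `check_mp` keeps the column-wise `Shard(1)` assertion for `gate_proj`/`up_proj`, but skips the fixed placement check for `embed_tokens`/`lm_head` when `share_embedding` is enabled, since a tied weight is deferred and resharded per stage rather than sharded once per layer. For reference, here is a minimal sketch of the style of placement check the test performs on a non-shared weight; the 2x2 mesh (dp=2, mp=2) and rank ids are assumptions, and the snippet is meant to run under a 4-rank paddle.distributed launch rather than as a single-process script.

import paddle
import paddle.distributed as dist

# Assumed 2x2 process mesh: 2-way data parallel, 2-way tensor (model) parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

linear = paddle.nn.Linear(16, 16, bias_attr=False)
# Column-wise shard along the 'mp' axis, replicated along 'dp'.
linear.weight = dist.shard_tensor(
    linear.weight, mesh, [dist.Replicate(), dist.Shard(1)]
)
# The same style of assertion check_mp uses for non-shared weights.
assert linear.weight.placements == [dist.Replicate(), dist.Shard(1)]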