
Commit dbf395f

Cherry pick some changes from incubate branch (#8862)
* Fix sharding reshard bug and support reshard in different partition ways (#8837)
* Fix sharding reshard bug and support reshard in different partition ways
* fix some bugs
* refine codes
* revert some useless codes
* add log
* Fix TP sync (#8846)
* fix sharding reshard compatibility for wrong master weight name (#8851)
* change print to logger.info
* remove useless f-string
1 parent c4d1abf commit dbf395f

File tree

4 files changed: +128 −18 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 17 additions & 4 deletions

@@ -195,7 +195,17 @@
 g_cpu_optimizer_state_dict = {}


-def _save_func(obj, path, saved_signal_path, protocol):
+def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            if k == "master_weights" and isinstance(v, dict):
+                for kk, vv in v.items():
+                    if isinstance(vv, paddle.Tensor):
+                        vv.name = name_mapping["master_weights"][kk]
+            else:
+                if k in name_mapping and isinstance(v, paddle.Tensor):
+                    v.name = name_mapping[k]
+
     paddle.save(obj, path, protocol)
     # dump savd_siganl
     with open(saved_signal_path, mode="w+") as f:

@@ -228,17 +238,18 @@ def clear_async_save_task_queue():
 def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
     global g_cpu_optimizer_state_dict
     g_cpu_optimizer_state_dict.clear()
+    name_mapping = {"master_weights": {}}
     for k, v in optimizer_state_dict.items():
         if k == "master_weights":
             g_cpu_optimizer_state_dict[k] = {}
             for kk, vv in v.items():
-                tensor_name = vv.name
                 g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
-                g_cpu_optimizer_state_dict[k][kk].name = tensor_name
+                name_mapping[k][kk] = vv.name
         elif k == "LR_Scheduler":
             g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
         else:
             g_cpu_optimizer_state_dict[k] = v.pin_memory()
+            name_mapping[k] = v.name
     paddle.device.synchronize()
     clear_async_save_task_queue()

@@ -248,7 +259,9 @@ def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol
     def start_process():
         nonlocal attempt
         try:
-            p = ctx.Process(target=_save_func, args=(g_cpu_optimizer_state_dict, path, saved_signal_path, protocol))
+            p = ctx.Process(
+                target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
+            )
             p.start()
             return p
         except Exception as e:
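
This change splits tensor renaming out of the async copy: async_save_optimizer now only records each tensor's original name into name_mapping, and _save_func re-applies those names inside the spawned save process right before paddle.save, so the pinned-memory copies are no longer renamed in the trainer process. Below is a minimal, Paddle-free sketch of that pattern; the pickle-based save and the placeholder names are illustrative only, not the trainer's actual API.

import multiprocessing as mp
import pickle


def save_with_names(state, name_mapping, path):
    # Re-attach the desired key names inside the worker, right before
    # serialization, instead of mutating the live objects in the parent process.
    renamed = {name_mapping.get(k, k): v for k, v in state.items()}
    with open(path, "wb") as f:
        pickle.dump(renamed, f)


if __name__ == "__main__":
    state = {"tmp_0": [1.0, 2.0], "tmp_1": [3.0]}  # stand-ins for pinned-memory copies
    name_mapping = {"tmp_0": "linear_0.w_0", "tmp_1": "linear_0.b_0"}  # original tensor names
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=save_with_names, args=(state, name_mapping, "opt_state.pkl"))
    p.start()
    p.join()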

paddlenlp/trainer/training_args.py

Lines changed: 11 additions & 3 deletions

@@ -1171,15 +1171,23 @@ def split_parallel_config(parallel_config):
                 # sync_param_name = [""] matches any parameter name.
                 # If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is :
                 # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].
+
+                if sync_param or sync_grad or sync_moment:
+                    logger.info("setting sync_param_name")
+                    strategy.sync_param_name = [""]
+
                 if sync_param:
+                    logger.info("setting sync_param")
                     strategy.hybrid_configs["mp_configs"].sync_param = True
-                    strategy.hybrid_configs["mp_configs"].sync_param_name = [""]
+
                 if sync_grad:
+                    logger.info("setting sync_grad")
                     strategy.hybrid_configs["mp_configs"].sync_grad = True
-                    strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]
+
                 if sync_moment:
+                    logger.info("setting sync_moment")
                     strategy.hybrid_configs["mp_configs"].sync_moment = True
-                    strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]
+
             except:
                 warnings.warn(
                     "The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "

paddlenlp/trainer/utils/reshard/sharding_v2.py

Lines changed: 59 additions & 2 deletions

@@ -14,11 +14,14 @@

 import numpy as np
 import paddle
+import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
     HybridParallelOptimizer,
 )
 from paddle.distributed.fleet.model import PipelineParallel

+from paddlenlp.utils.log import logger
+
 from ....transformers.model_utils import unwrap_optimizer

 try:

@@ -29,6 +32,9 @@
     DygraphShardingOptimizerV2 = None


+from paddle.distributed.communication.reduce import ReduceOp
+
+
 def shard(node_model_state, model, optimizer, hcg):
     assert DygraphShardingOptimizerV2 is not None
     group = hcg.get_sharding_parallel_group()

@@ -137,7 +143,7 @@ def slice_tensor(tensor, begin, end):
     return tensor[begin:end]


-def collect_split_info(optimizer, model):
+def collect_split_info(optimizer, model, only_return_lengths=False):
     split_infos = {}

     def gather_infos(comm_buffer):

@@ -146,7 +152,13 @@ def gather_infos(comm_buffer):
             padded_size = v._padded_size
             buffer_size = v._param_buffer._numel()
             has_slice_grad = v._slice_grad is not None
-            split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)
+            if only_return_lengths:
+                if v._param_begin < v._param_end:
+                    split_infos[k] = v._param_end - v._param_begin
+                else:
+                    split_infos[k] = None
+            else:
+                split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)

     if isinstance(model, PipelineParallel) and model._sharding_comm_overlap > 0:
         optimizer = unwrap_optimizer(optimizer, HybridParallelOptimizer)

@@ -167,6 +179,51 @@ def gather_infos(comm_buffer):
     return split_infos


+def is_matched_optimizer_state_dict(opt_state_dict, optimizer, model, hcg=None, need_allgather=True):
+    split_infos = collect_split_info(optimizer, model, only_return_lengths=True)
+    master_weights = opt_state_dict.get("master_weights", None)
+
+    def get_matched_length(name):
+        if master_weights and name in master_weights:
+            tensor = master_weights[name]
+        else:
+            moment_name = name + "_moment1_0"
+            if moment_name not in opt_state_dict:
+                return None
+
+            tensor = opt_state_dict[moment_name]
+        if isinstance(tensor, (list, tuple)):
+            assert len(tensor) == 2, tensor
+            assert isinstance(tensor[0], str), tensor[0]
+            tensor = tensor[1]
+        shape = tensor.shape
+        assert len(shape) == 1, shape
+        length = shape[0]
+        return length
+
+    is_matched = 1
+    for k, length in split_infos.items():
+        matched_length = get_matched_length(k)
+        if length != matched_length:
+            is_matched = 0
+            break
+
+    if need_allgather:
+        if hcg is None:
+            hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_sharding_parallel_group()
+        if group is not None and group.nranks > 1:
+            x = paddle.to_tensor([is_matched], dtype=paddle.int32)
+            paddle.distributed.stream.all_reduce(x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True)
+            global_is_matched = int(x.numpy()[0])
+        else:
+            global_is_matched = is_matched
+
+    global_is_matched = True if global_is_matched else False
+    logger.info(f"Sharding reshard checkpoint: local_match = {is_matched} , global_match = {global_is_matched}")
+    return global_is_matched
+
+
 def is_bata(name):
     if "_beta1_pow_acc_" in name:
         return True
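
The new is_matched_optimizer_state_dict helper reduces the compatibility check to a per-rank yes/no: each rank compares the length of every 1-D slice it owns (looked up in the checkpoint's master_weights, or in the corresponding *_moment1_0 entry as a fallback) against what collect_split_info(..., only_return_lengths=True) reports, and the sharding group then takes the MIN of those flags so the checkpoint is reused only when every rank agrees. The single-process sketch below illustrates that consensus rule; the parameter names and lengths are made up, and the real code performs the MIN via paddle.distributed.stream.all_reduce with ReduceOp.MIN over the sharding group.

def local_match(owned_lengths, ckpt_lengths):
    # 1 if every locally owned slice has the expected 1-D length in the checkpoint, else 0.
    return int(all(ckpt_lengths.get(name) == length for name, length in owned_lengths.items()))


def global_match(per_rank_flags):
    # MIN over ranks: the checkpoint is reusable only if every rank matched locally.
    return bool(min(per_rank_flags))


# Example: rank 1 sees an unexpected slice length, so the whole group falls back to resharding.
flags = [
    local_match({"linear_0.w_0": 4096}, {"linear_0.w_0": 4096}),  # rank 0
    local_match({"linear_0.w_0": 4096}, {"linear_0.w_0": 2048}),  # rank 1
]
assert global_match(flags) is False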

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 41 additions & 9 deletions

@@ -40,7 +40,7 @@
 from paddlenlp.utils.log import logger

 from . import reshard as reshard_util
-from .reshard import SHARDING_STRATEGY_V1, pp_reshard
+from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard

 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"

@@ -204,10 +204,21 @@ def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimize
         path = os.path.join(checkpoint, optimizer_name)
         logger.info(f"load optimizer state from {path}")
         if os.path.isfile(path):
-            return paddlenlp_load(path, map_location="cpu")
+            return self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu"))
         logger.info(f"{path} not exists")
         return None

+    def _modify_ckpt_for_compatibility(self, ckpt):
+        master_weights = ckpt.get("master_weights", None)
+        if master_weights:
+            for k, v in master_weights.items():
+                assert isinstance(v, paddle.Tensor), v
+                if not v.name.startswith(k):
+                    new_name = k + "_fp32_master_0"
+                    logger.info(f"Modify master weights {v.name} -> {new_name}")
+                    v.name = new_name
+        return ckpt
+
     def _need_reshard(self, checkpoint):
         if self._need_reshard_pp(checkpoint):
             return True

@@ -253,10 +264,6 @@ def _need_reshard_pp(self, checkpoint):
     def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wrapped):
         """load state_dict of multiple shard from_checkpoint, Only load model state dict."""

-        if not self._need_reshard(checkpoint):
-            logger.info("do not need reshard")
-            return self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, self.args.optimizer_name_suffix)
-        logger.info("reshard optimizer state")
         parallel_config = self._load_distributed_strategy(checkpoint)
         sharding_meta = self._load_sharding_meta(checkpoint, 0)
         pp_degree = parallel_config["pp_degree"]

@@ -276,16 +283,41 @@
         cur_sharding_degree = self.args.sharding_parallel_degree
         cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)

+        if not self._need_reshard(checkpoint):
+            one_shard_opt_state_dict = self._load_optimizer_state_of_one_shard(
+                checkpoint, base_opt_name, self.args.optimizer_name_suffix
+            )
+
+            if sharding_strategy == SHARDING_STRATEGY_V2 and cur_sharding_strategy == SHARDING_STRATEGY_V2:
+                is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
+                    one_shard_opt_state_dict, self.optimizer, model_wrapped
+                )
+            else:
+                is_matched = True
+
+            if is_matched:
+                logger.info("do not need reshard")
+                return one_shard_opt_state_dict
+        else:
+            one_shard_opt_state_dict = None
+
+        logger.info("reshard optimizer state")
+
         def load_model_slices():
             model_state = reshard_util.NodeModelState()
             for j in range(self.args.pipeline_parallel_rank, pp_degree, cur_pp_degree):
                 cur_sharding_meta = self._load_sharding_meta(checkpoint, j)
                 assert "structure_name_mapping" in cur_sharding_meta
                 structure_name_map = cur_sharding_meta["structure_name_mapping"]
                 for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
-                    tmp = self._load_optimizer_state_of_one_shard(
-                        checkpoint, base_opt_name, self.args.sharded_name_suffix(i, j)
-                    )
+                    sharded_name_suffix = self.args.sharded_name_suffix(i, j)
+                    if one_shard_opt_state_dict is None:
+                        tmp = self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, sharded_name_suffix)
+                    else:
+                        assert (
+                            self.args.optimizer_name_suffix == sharded_name_suffix
+                        ), f"{self.args.optimizer_name_suffix} vs {sharded_name_suffix}"
+                        tmp = one_shard_opt_state_dict
                     node_model_state_tmp = reshard_util.NodeModelState()
                     node_model_state_tmp.add_opts(tmp)
                     node_model_state_tmp.pack_keys(structure_name_map)
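
Taken together, load_optimizer_state_with_reshard now defers its "do not need reshard" early return until it knows the single shard actually fits: when the parallel topology is unchanged and both the checkpoint and the current run use sharding V2, the freshly loaded shard is validated with is_matched_optimizer_state_dict, and on a mismatch it is kept in memory and fed into the reshard path instead of being reloaded per slice. The sketch below is a simplified restatement of that decision flow; the helper name is illustrative, not part of the trainer.

def choose_load_path(need_reshard: bool, both_sharding_v2: bool, lengths_match: bool) -> str:
    # Mirrors the branching added in load_optimizer_state_with_reshard.
    if not need_reshard:
        if not both_sharding_v2 or lengths_match:
            return "return the single shard as-is"
        # Topology unchanged but the V2 slice lengths disagree (e.g. stale
        # master-weight names): reshard, reusing the shard already in memory.
        return "reshard, reusing the already-loaded shard"
    return "reshard, loading every source shard"


print(choose_load_path(need_reshard=False, both_sharding_v2=True, lengths_match=False))
# -> reshard, reusing the already-loaded shard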
