
Commit 14148f2

Merge branch 'develop' into update_distloader
2 parents: 7ea22e8 + d6ac1bd

15 files changed: +520 additions, -53 deletions


model_zoo/bert/run_pretrain.py

Lines changed: 4 additions & 0 deletions
@@ -424,6 +424,10 @@ def do_train(args):
             optimizer.step()
             lr_scheduler.step()
             optimizer.clear_grad()
+
+            # NOTE: For accurate data statistics, please open the comments below,especially when args.logging_steps==1.
+            # if global_step % args.logging_steps == 0:
+            # loss = loss.numpy()
             total_samples += args.batch_size
             train_run_cost = time.time() - batch_start
             train_cost_avg.record(train_run_cost)
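
Note on the commented-out logging snippet above: calling .numpy() on the loss copies it to the host and synchronizes the device, which is why the commit leaves it gated on args.logging_steps instead of running it every step. A minimal, self-contained sketch of that pattern (illustrative values only; the real loop uses the loss, global_step, and args from run_pretrain.py):

import paddle

# Hedged sketch, not part of the commit: fetch the loss for logging only on
# logging steps, since .numpy() forces a device-to-host copy and a sync.
logging_steps = 20
for global_step in range(1, 101):
    loss = paddle.mean(paddle.rand([8]))          # stand-in for the real training loss
    if global_step % logging_steps == 0:
        print(global_step, float(loss.numpy()))   # synchronizes before reading the value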

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 24 additions & 20 deletions
@@ -30,6 +30,7 @@
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _load_state_dict_into_model,
+    faster_set_state_dict,
     get_parameter_dtype,
     load_state_dict,
     unwrap_model,
@@ -65,9 +66,10 @@
 from paddlenlp.utils.nested import nested_copy, nested_copy_place

 if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open

 FP32_MASTER = "fp32_master_0"
 optimizer_scalar_name = [
@@ -196,7 +198,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
     Returns:
         None
     """
-
     if paddle.distributed.get_world_size() <= 1:
         load_single_card_checkpoint(args, model, resume_from_checkpoint)
         return
@@ -222,7 +223,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
         pretrained_model_name_or_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
     )
-
     loaded_keys = sharded_metadata["all_checkpoint_keys"]

     model_state_dict = get_expected_state_dict(model)
@@ -266,7 +266,9 @@ def _remove_unused_keys(
         else:
             tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
         # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )

         if not pre_tensor_parallel_split:
             # Since we load all keys but we only need one of pipeline stages
@@ -279,11 +281,12 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )

-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

         # force memory release
         del state_dict
-        gc.collect()
+        # gc.collect()

         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
@@ -337,6 +340,7 @@ def unified_checkpoint_into_shards(
         tp_actions = model_to_save.get_tensor_parallel_convert_actions(
             model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
         )
+        logger.info("Unified model tensor parallel weights in shards")
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

     # build index json file
@@ -490,6 +494,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
         resolved_archive_file = [resolved_archive_file]
+
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -537,10 +542,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                 tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

                 # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
             else:
                 # for pipeline model, we don't need to use tp_actions
-                state_dict = load_state_dict(shard_file, None, expected_keys)
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

             returned_state_dict.update(state_dict)
             # force memory release
@@ -553,7 +558,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         state_dict_master_weight = load_resolved_archive_file(
             resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
         )
-
     # rename optimizer param
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
@@ -562,13 +566,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name

     if has_master_weights:
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     returned_optim_state_dict = nested_copy_place(
@@ -640,6 +644,7 @@ def unified_optimizer_into_shards(
         tp_actions = model.get_tensor_parallel_convert_actions(
             model.config, model_keys, is_split=False, ignore_error=True
         )
+        logger.info("Unified optimizer tensor parallel in shards")
         optim_state_dict = merge_tensor_parallel_for_optimizer(
             optim_state_dict,
             tp_actions,
@@ -648,6 +653,7 @@
         paddle.device.cuda.empty_cache()

         if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
             master_weights = merge_tensor_parallel_for_optimizer(
                 master_weights,
                 tp_actions,
@@ -703,7 +709,6 @@ def unified_optimizer_into_shards(
 def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
     index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
     # Find index json file and distribute this file in global group.
     if distributed_isfile(index_filename):
         distributed_file(index_filename)
@@ -1605,7 +1610,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()

-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )

     if tp_group.nranks > 1:
         dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1721,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1741,7 +1746,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[key] = tensor
@@ -1754,8 +1759,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


 def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function for UC
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1774,14 +1778,14 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
                     tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                        tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor
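
A recurring detail in the optimizer-loading changes above is re-keying entries with dict.pop() (state_dict_optim.pop(key), state_dict_master_weight.pop(key)) instead of plain indexing, so the source dict drops its reference as soon as the tensor is filed under its new name. A small, self-contained sketch of that pattern, with plain Python objects standing in for tensors:

# Hedged sketch, not part of the commit: re-key entries with dict.pop() so the
# old mapping releases each value immediately instead of keeping a second reference.
def rekey_with_pop(src, rename):
    dst = {}
    for key in list(src.keys()):          # snapshot keys; src is mutated while iterating
        dst[rename(key)] = src.pop(key)
    return dst

optim = {"embedding.weight/moment1_0": "tensor-A", "linear.weight/moment1_0": "tensor-B"}
renamed = rekey_with_pop(optim, lambda k: k.replace("/", "_"))
print(renamed)   # {'embedding.weight_moment1_0': 'tensor-A', 'linear.weight_moment1_0': 'tensor-B'}
print(optim)     # {} -- the source dict no longer holds references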

paddlenlp/trainer/trainer.py

Lines changed: 8 additions & 3 deletions
@@ -2436,6 +2436,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.runtime_timer.stop()
             return

+        logger.info("Loading optimizer and scheduler...")
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
             self.runtime_timer.stop()
             return
@@ -2765,11 +2766,15 @@ def evaluation_loop(
         # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
         # samplers has been rounded to a multiple of batch_size, so we truncate.
         if all_losses is not None:
-            all_losses = all_losses[:num_samples]
+            all_losses = all_losses[: num_samples * int(self.args.world_size / self.args.dataset_world_size)]
         if all_preds is not None:
-            all_preds = nested_truncate(all_preds, num_samples)
+            all_preds = nested_truncate(
+                all_preds, num_samples * int(self.args.world_size / self.args.dataset_world_size)
+            )
         if all_labels is not None:
-            all_labels = nested_truncate(all_labels, num_samples)
+            all_labels = nested_truncate(
+                all_labels, num_samples * int(self.args.world_size / self.args.dataset_world_size)
+            )

         model.train()
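
The evaluation_loop change scales the truncation bound by world_size / dataset_world_size. A plausible reading (an assumption, not stated in the commit): when only dataset_world_size ranks hold distinct data shards, the gathered losses, predictions, and labels contain each sample world_size / dataset_world_size times, so truncating to plain num_samples would cut off real entries. With assumed numbers:

# Illustrative numbers only (assumed, not from the commit).
world_size = 8             # total ranks
dataset_world_size = 4     # ranks that hold distinct data shards
num_samples = 1000         # samples in the eval dataset

duplication = int(world_size / dataset_world_size)    # each sample gathered 2 times
truncate_to = num_samples * duplication               # keep 2000 gathered entries
print(duplication, truncate_to)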

paddlenlp/transformers/conversion_utils.py

Lines changed: 13 additions & 3 deletions
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):

     if isinstance(weight_list[0], np.ndarray):
         return np.concatenate([reorder[i] for i in index], axis=axis)
+    else:
+        tensor = paddle.concat([reorder[i] for i in index], axis=axis)

-    return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
+    if tensor.place.is_gpu_place():
+        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+    return tensor


 def naive_fuse_split_tp(
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=-1)
         else:
-            return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=-1)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor
     else:
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=0)
         else:
-            return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=0)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor


 def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
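
All three branches above use the same guarded copy: concatenate first, then move the result into CUDA pinned host memory only if it actually lives on the GPU. A minimal, self-contained sketch of that guard (the helper name is hypothetical; on a CPU-only Paddle build the guard is simply a no-op):

import paddle

# Hedged sketch of the guarded copy used above: move a tensor into pinned host
# memory only when it is on the GPU; otherwise leave it where it is.
def to_pinned_if_on_gpu(tensor):
    if tensor.place.is_gpu_place():
        return tensor._copy_to(paddle.CUDAPinnedPlace(), False)   # non-blocking copy
    return tensor

x = paddle.concat([paddle.rand([2, 4]), paddle.rand([2, 4])], axis=0)
x = to_pinned_if_on_gpu(x)
print(x.place)   # pinned place on a GPU build, CPU place otherwise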

paddlenlp/transformers/llama/modeling.py

Lines changed: 9 additions & 8 deletions
@@ -96,9 +96,11 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-global npu_is_casual
+
+
 npu_is_casual = False

+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -213,7 +215,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-    global npu_is_casual
+
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1119,7 +1121,6 @@ def __init__(self, config, layerwise_recompute: bool = False):
         self.layerwise_recompute = layerwise_recompute
         self.recompute_granularity = config.recompute_granularity

-
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -1613,14 +1614,12 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
-        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
                     attention_mask = None
             else:
-                npu_is_casual = is_casual
                 attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
         # decoder layers
@@ -1728,10 +1727,12 @@ def forward(self, prediction_scores, masked_lm_labels):
         # skip ignore_index which loss == 0
         # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
         # loss = paddle.mean(masked_lm_loss)
-        binary_sequence = paddle.where(masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss))
+        binary_sequence = paddle.where(
+            masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
+        )
         sum_ = paddle.sum(binary_sequence)
-        loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
-
+        loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
+
         return loss
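
The reformatted LlamaPretrainingCriterion above averages the loss over only the tokens that were not ignored: a 0/1 mask marks positions with positive loss, and the masked sum is divided by the mask count, falling back to 0 when every token is ignored. A small standalone sketch with made-up per-token losses:

import paddle

# Hedged sketch of the masked mean computed in LlamaPretrainingCriterion:
# average the per-token loss over tokens whose loss is > 0 (i.e. not ignored).
masked_lm_loss = paddle.to_tensor([0.0, 2.0, 0.0, 4.0])   # made-up per-token losses

binary_sequence = paddle.where(
    masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
)
sum_ = paddle.sum(binary_sequence)
loss = 0 if sum_ == 0 else paddle.sum(masked_lm_loss * binary_sequence) / sum_
print(float(loss))   # (2.0 + 4.0) / 2 = 3.0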
