
Conversation

@ZHUI ZHUI (Contributor) commented Sep 22, 2022

PR types

New features

PR changes

APIs

Description

Support sharding for the Trainer (a usage sketch follows below).

stage1: supported

stage2: partially supported

  • offload is not supported yet; pure_fp16 needs to be fixed first

stage3: not supported yet

  • model saving is broken
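A minimal configuration sketch of how the new option might be used, assuming the sharding field this PR adds to the PaddleNLP TrainingArguments (the values shown are illustrative, not the definitive interface):

from paddlenlp.trainer import TrainingArguments

# Hypothetical usage sketch of the new sharding option (illustrative only):
# stage1 is supported, stage2 partially (no offload), stage3 not yet.
args = TrainingArguments(
    output_dir="./outputs",
    fp16=True,             # mixed-precision training
    scale_loss=32768,      # initial loss scale for fp16 (see the --scale_loss help)
    sharding="stage1",     # value assumed; see the --sharding help added below
)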

@ZHUI ZHUI marked this pull request as ready for review September 23, 2022 10:03
fused_allreduce_gradients(list(model.parameters()), None)

if self.do_grad_scaling:
Contributor Author

@haohongxiang Please keep the scaler usage here consistent with the official scaler API.


self.save_model(output_dir)

if self.sharding is not None:
Contributor Author

@haohongxiang Please provide an interface that gathers the parameters to CPU on rank 0.
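A rough illustration of the kind of helper being requested (entirely hypothetical; a real sharding-aware implementation additionally has to all-gather the parameter shards before moving them, which this sketch omits):

import paddle.distributed as dist

# Hypothetical sketch: move the local state dict to CPU and keep it only on rank 0.
def state_dict_to_cpu_on_rank0(model):
    if dist.get_rank() != 0:
        return None
    return {name: param.cpu() for name, param in model.state_dict().items()}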

Comment on lines 683 to 684
if self.do_grad_scaling:
self.scaler.minimize(self.optimizer, tr_loss)
# TODO: fix sharding stage2 stage3 with original scaler useage.
Contributor Author

The API usage here is problematic.
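For reference, the official paddle.amp.GradScaler flow the review points to scales the loss before backward and then calls minimize with the scaled loss (sketch below; the model, optimizer, and loss are placeholders):

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

with paddle.amp.auto_cast():
    loss = model(paddle.randn([2, 4])).mean()

scaled = scaler.scale(loss)         # scale the loss before backward
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscales grads, skips the step on inf/nan, updates the scale
optimizer.clear_grad()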

@ZHUI ZHUI requested a review from gongweibao October 9, 2022 09:43

The value of initial scale_loss for fp16. (default: 32768)

--sharding
Contributor

It would be good to add a NOTICE here describing the current support status.

Contributor Author

Added.

|--------------------------------|-------|-------|-------------|------------------|-------------|-------------|------|-------|
| | mcc | acc | acc | pearson | acc | acc | acc | acc |
| T5-v1_1-base Paddle | 47.6845 | 94.38 | 84.31 | 87.74 | 88.05 | 85.39 | 90.518 | 65.70 |
| epoch | 100 | 10 | 100 | 100 | 3 | 3 | 10 | 100 |
Contributor

Have the T5_v1_1_base results been aligned?

Contributor Author

As discussed offline.

if self.label2id:
label = self.label2id[label]
if pred not in self.label2id:
pred = 0
Contributor

Why is pred = 1 here when label is 0? Please explain the logic in more detail.

When the generated label is not in the label list, it is ultimately assigned label 0. This case cannot occur with an encoder model, so could you explain specifically how the metrics were aligned here?

Contributor Author

Comments have been added.
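For readers of this thread, a small sketch of the fallback being discussed (hypothetical label map, not the PR's exact code): a seq2seq model generates the label as a string, and a generation that is not a known label is mapped to id 0 so the metric can still be computed; an encoder-style classifier can never produce such an out-of-vocabulary label.

label2id = {"entailment": 0, "not_entailment": 1}   # hypothetical label map

def to_ids(gold_label, generated_text):
    gold = label2id[gold_label]
    pred = label2id.get(generated_text, 0)   # unknown generations fall back to label 0
    return gold, pred

print(to_ids("not_entailment", "banana"))    # -> (1, 0): counted as a miss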

@ZHUI ZHUI requested a review from wawltor October 24, 2022 11:44
haohongxiang previously approved these changes Oct 25, 2022
Contributor

@haohongxiang haohongxiang left a comment

LGTM for sharding+dp

next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
next_scores = paddle.index_sample(probs, next_tokens)
next_scores = paddle.index_sample(probs.astype("float32"),
next_tokens)
Contributor Author

index_sample has no fp16/bf16 kernel, hence the cast to float32 here.

f"{self.dtype} not recognized. `dtype` should be set to either `paddle.float32` or `paddle.float16`"
)
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
Contributor Author

For bf16 dtype
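A small sketch of the additive-mask trick in the line above (illustrative): masked positions get -1e4 rather than a larger negative constant because -1e4 is still representable in float16 (whose maximum magnitude is about 65504) as well as in bfloat16.

import paddle

attention_mask = paddle.to_tensor([[1.0, 1.0, 0.0, 0.0]])   # 1 = attend, 0 = mask
extended_mask = (1.0 - attention_mask) * -1e4                # 0 for kept, -1e4 for masked positions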

loss = loss_fct(
    lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]).astype("float32"),
    labels.flatten())
Contributor Author

CrossEntropyLoss has no fp16/bf16 kernel.
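A minimal sketch of the cast shown above (illustrative shapes, not the PR's exact code): the logits are cast to float32 before the loss because CrossEntropyLoss lacks fp16/bf16 kernels.

import paddle
import paddle.nn as nn

loss_fct = nn.CrossEntropyLoss()
lm_logits = paddle.randn([2, 5, 32]).astype("float16")   # [batch, seq_len, vocab]
labels = paddle.randint(0, 32, shape=[2, 5])

loss = loss_fct(
    lm_logits.reshape([-1, lm_logits.shape[-1]]).astype("float32"),
    labels.flatten())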

wawltor previously approved these changes Nov 14, 2022
Contributor

@wawltor wawltor left a comment

LGTM

Contributor

@wawltor wawltor left a comment

LGTM

@ZHUI ZHUI merged commit b35b8d6 into PaddlePaddle:develop Nov 15, 2022