Skip to content

Commit dd376ce

Browse files
authored
[ModelZoo] Refactor ERNIE-M usage in Model Zoo (#4324)
* update model_zoo/ernie-m * add logging eval infomations * delete unused testing code * add tests samples for xnli * fewer tests samples for xnli * add predict and export * add predictor and serving * fix export error when training with data parallelism * fix mkdir error in distributed training and use ERNIEMHandler * adjust predictor * modify ci accroding to #4398 and adjust predictor * enable load tiny dataset for ci * support testing infer with precision_mode fp16 * modify ci script * Update ci_case.sh * enable to test inputs_embeds for enire-m * fix ci script * fix using fast tokenizer * using set_optimizer_grouped_parameters instead * consturct optimizer with layerwise_lr_decay out of Trainer
1 parent a776002 commit dd376ce

File tree

15 files changed

+1134
-266
lines changed

15 files changed

+1134
-266
lines changed

model_zoo/ernie-m/README.md

Lines changed: 129 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
# ERNIE-M
22

3-
* [模型简介](#模型简介)
4-
* [快速开始](#快速开始)
5-
* [通用参数释义](#通用参数释义)
6-
* [自然语言推断任务](#自然语言推断任务)
3+
* [模型介绍](#模型介绍)
4+
* [开始运行](#开始运行)
5+
* [环境要求](#环境要求)
6+
* [数据准备](#数据准备)
7+
* [模型训练](#模型训练)
8+
* [参数释义](#参数释义)
9+
* [单卡训练](#单卡训练)
10+
* [单机多卡](#单机多卡)
11+
* [预测评估](#预测评估)
12+
* [部署](#部署)
13+
* [Python部署](#Python部署)
14+
* [服务化部署](#服务化部署)
715
* [参考论文](#参考论文)
816

9-
## 模型简介
17+
## 模型介绍
1018

1119
[ERNIE-M](https://arxiv.org/abs/2012.15674) 是百度提出的一种多语言语言模型。原文提出了一种新的训练方法,让模型能够将多种语言的表示与单语语料库对齐,以克服平行语料库大小对模型性能的限制。原文的主要想法是将回译机制整合到预训练的流程中,在单语语料库上生成伪平行句对,以便学习不同语言之间的语义对齐,从而增强跨语言模型的语义建模。实验结果表明,ERNIE-M 优于现有的跨语言模型,并在各种跨语言下游任务中提供了最新的 SOTA 结果。
1220
原文提出两种方法建模各种语言间的对齐关系:
@@ -17,67 +25,145 @@
1725

1826
![framework](https://user-images.githubusercontent.com/40912707/201308423-bf4f0100-3ada-4bae-89d5-b07ffec1e2c0.png)
1927

20-
本项目是 ERNIE-M 的 PaddlePaddle 动态图实现, 包含模型训练,模型验证等内容。以下是本例的简要目录结构及说明:
28+
本项目是 ERNIE-M 的 PaddlePaddle 动态图实现,包含模型训练,模型验证等内容。以下是本例的简要目录结构及说明:
2129

2230
```text
2331
.
2432
├── README.md # 文档
2533
├── run_classifier.py # 自然语言推断任务
2634
```
2735

28-
## 快速开始
36+
## 开始运行
2937

30-
### 通用参数释义
38+
下面提供以XNLI数据集进行模型微调相关训练、预测、部署的代码,XNLI数据集是MNLI的子集,并且已被翻译成14种不同的语言(包含一些较低资源语言)。与MNLI一样,目标是预测文本蕴含(句子 A 是否暗示/矛盾/都不是句子 B )。
39+
40+
### 环境要求
41+
42+
python >= 3.7
43+
paddlepaddle >= 2.3
44+
paddlenlp >= 2.4.9
45+
paddle2onnx >= 1.0.5
46+
47+
### 数据准备
48+
49+
此次微调数据使用XNLI数据集, 可以通过下面的方式来使用数据集
50+
51+
```python
52+
from datasets import load_dataset
53+
54+
# all_languages = ["ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", "ur", "vi", "zh"]
55+
# load xnli dataset of english
56+
train_ds, eval_ds, test_ds = load_dataset("xnli", "en", split=["train_ds", "validation", "test"])
57+
```
58+
59+
### 模型训练
60+
61+
#### 参数释义
3162

3263
- `task_type` 表示了自然语言推断任务的类型,目前支持的类型为:"cross-lingual-transfer", "translate-train-all"
3364
,分别表示在英文数据集上训练并在所有15种语言数据集上测试、在所有15种语言数据集上训练和测试。
3465
- `model_name_or_path` 指示了 Fine-tuning 使用的具体预训练模型以及预训练时使用的tokenizer,目前支持的预训练模型有:"ernie-m-base", "ernie-m-large"
35-
。若模型相关内容保存在本地,这里也可以提供相应目录地址,例如:"./checkpoint/model_xx/"。
66+
。若模型相关内容保存在本地,这里也可以提供相应目录地址,例如:"./finetuned_models"。
67+
- `do_train` 是否进行训练任务。
68+
- `do_eval` 是否进行评估任务。
69+
- `do_predict` 是否进行评测任务。
70+
- `do_export` 是否导出模型。
3671
- `output_dir` 表示模型保存路径。
72+
- `export_model_dir` 模型的导出路径。
73+
- `per_device_train_batch_size` 表示训练时每次迭代**每张**卡上的样本数目。
74+
- `per_device_eval_batch_size` 表示验证时每次迭代**每张**卡上的样本数目。
3775
- `max_seq_length` 表示最大句子长度,超过该长度将被截断,不足该长度的将会进行 padding。
38-
- `memory_length` 表示当前的句子被截取作为下一个样本的特征的长度。
3976
- `learning_rate` 表示基础学习率大小,将于 learning rate scheduler 产生的值相乘作为当前学习率。
77+
- `classifier_dropout` 表示模型用于分类的 dropout rate ,默认是0.1。
4078
- `num_train_epochs` 表示训练轮数。
4179
- `logging_steps` 表示日志打印间隔步数。
4280
- `save_steps` 表示模型保存及评估间隔步数。
43-
- `batch_size` 表示每次迭代**每张**卡上的样本数目。
44-
- `weight_decay` 表示AdamW的权重衰减系数。
4581
- `layerwise_decay` 表示 AdamW with Layerwise decay 的逐层衰减系数。
46-
- `adam_epsilon` 表示AdamW优化器的 epsilon。
47-
- `warmup_proportion` 表示学习率warmup系数。
82+
- `warmup_rate` 表示学习率warmup系数。
4883
- `max_steps` 表示最大训练步数。若训练`num_train_epochs`轮包含的训练步数大于该值,则达到`max_steps`后就提前结束。
4984
- `seed` 表示随机数种子。
5085
- `device` 表示训练使用的设备, 'gpu'表示使用 GPU, 'xpu'表示使用百度昆仑卡, 'cpu'表示使用 CPU。
51-
- `use_amp` 表示是否启用自动混合精度训练。
86+
- `fp16` 表示是否启用自动混合精度训练。
5287
- `scale_loss` 表示自动混合精度训练的参数。
53-
54-
### 自然语言推断任务
55-
56-
#### 数据集介绍
57-
XNLI 是 MNLI 的子集,并且已被翻译成14种不同的语言(包含一些较低资源语言)。与 MNLI 一样,目标是预测文本蕴含(句子 A 是否暗示/矛盾/都不是句子 B )。
88+
- `load_best_model_at_end` 训练结束后是否加载最优模型,通常与`metric_for_best_model`配合使用。
89+
- `metric_for_best_model` 最优模型指标,如`eval_accuarcy`等,用于比较模型好坏。
5890

5991
#### 单卡训练
6092

93+
`run_classifier.py`是模型微调脚本,可以使用如下命令对预训练模型进行微调训练。
94+
6195
```shell
6296
python run_classifier.py \
63-
--task_type cross-lingual-transfer \
64-
--batch_size 16 \
65-
--model_name_or_path ernie-m-base \
66-
--save_steps 12272 \
67-
--output_dir output
97+
--do_train \
98+
--do_eval \
99+
--do_export \
100+
--task_type cross-lingual-transfer \
101+
--model_name_or_path ernie-m-base \
102+
--output_dir ./finetuned_models/ \
103+
--export_model_dir ./finetuned_models/ \
104+
--per_device_train_batch_size 16 \
105+
--per_device_eval_batch_size 16 \
106+
--max_seq_length 256 \
107+
--learning_rate 5e-5 \
108+
--classifier_dropout 0.1 \
109+
--weight_decay 0.0 \
110+
--layerwise_decay 0.8 \
111+
--save_steps 12272 \
112+
--eval_steps 767 \
113+
--num_train_epochs 5 \
114+
--warmup_ratio 0.1 \
115+
--load_best_model_at_end True \
116+
--metric_for_best_model eval_accuracy \
117+
--overwrite_output_dir
118+
```
119+
120+
#### 单机多卡
121+
122+
同样,可以执行如下命令实现多卡训练
123+
124+
```shell
125+
python -m paddle.distributed.launch --gpus 0,1 run_classifier.py \
126+
--do_train \
127+
--do_eval \
128+
--do_export \
129+
--task_type cross-lingual-transfer \
130+
--model_name_or_path ernie-m-base \
131+
--output_dir ./finetuned_models/ \
132+
--export_model_dir ./finetuned_models/ \
133+
--per_device_train_batch_size 16 \
134+
--per_device_eval_batch_size 16 \
135+
--max_seq_length 256 \
136+
--learning_rate 5e-5 \
137+
--classifier_dropout 0.1 \
138+
--weight_decay 0.0 \
139+
--layerwise_decay 0.8 \
140+
--save_steps 12272 \
141+
--eval_steps 767 \
142+
--num_train_epochs 5 \
143+
--warmup_ratio 0.1 \
144+
--load_best_model_at_end True \
145+
--metric_for_best_model eval_accuracy \
146+
--overwrite_output_dir \
147+
--remove_unused_columns False
68148
```
69149

70-
#### 多卡训练
150+
这里设置额外的参数`--remove_unused_columns``False`是因为数据集中不需要的字段已经被手动去除了。
151+
152+
#### 预测评估
153+
154+
当训练完成后,可以直接加载训练保存的模型进行评估,此时`--model_name_or_path`传入训练时的`output_dir``./finetuned_models`
71155

72156
```shell
73-
python -m paddle.distributed.launch --gpus 0,1 --log_dir output run_classifier.py \
157+
python run_classifier.py \
158+
--do_predict \
74159
--task_type cross-lingual-transfer \
75-
--batch_size 16 \
76-
--model_name_or_path ernie-m-base \
77-
--save_steps 12272 \
78-
--output_dir output
160+
--model_name_or_path ./finetuned_models \
161+
--output_dir ./finetuned_models
79162
```
80163

164+
预测结果(label)和预测的置信度(confidence)将写入`./finetuned_models/test_results.json`文件。
165+
166+
81167
在XNLI数据集上微调 cross-lingual-transfer 类型的自然语言推断任务后,在测试集上有如下结果
82168
| Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur | Avg |
83169
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
@@ -100,6 +186,18 @@ python -m paddle.distributed.launch --gpus 0,1 --log_dir output run_classifier.p
100186
| XLM-R Large | 89.1 | 85.1 | 86.6 | 85.7 | 85.3 | 85.9 | 83.5 | 83.2 | 83.1 | 83.7 | 81.5 | **83.7** | **81.6** | 78.0 | 78.1 | 83.6 |
101187
| VECO Large | 88.9 | 82.4 | 86.0 | 84.7 | 85.3 | 86.2 | **85.8** | 80.1 | 83.0 | 77.2 | 80.9 | 82.8 | 75.3 | **83.1** | **83.0** | 83.0 |
102188
| **ERNIE-M Large** | **89.5** | **86.5** | **86.9** | **86.1** | **86.0** | **86.8** | 84.1 | **83.8** | **84.1** | **84.5** | **82.1** | 83.5 | 81.1 | 79.4 | 77.9 | **84.2** |
189+
190+
## 部署
191+
192+
### Python部署
193+
194+
Python部署请参考:[Python 部署指南](./deploy/predictor/README.md)
195+
196+
### 服务化部署
197+
198+
* [PaddleNLp SimpleServing 服务化部署指南](./deploy/simple_serving/README.md)
199+
200+
103201
## 参考论文
104202

105203
[Ouyang X , Wang S , Pang C , et al. ERNIE-M: Enhanced Multilingual Representation by Aligning Cross-lingual Semantics with Monolingual Corpora[J]. 2020.](https://arxiv.org/abs/2012.15674)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# ERNIEM Python部署指南
2+
本文介绍 ERNIE 3.0 Python 端的部署,包括部署环境的准备,序列标注和分类两大场景下的使用示例。
3+
- [ERNIE-M Python 部署指南](#ERNIEM-Python部署指南)
4+
- [1. 环境准备](#1-环境准备)
5+
- [1.1 CPU 端](#11-CPU端)
6+
- [1.2 GPU 端](#12-GPU端)
7+
- [2. 分类模型推理](#2-分类模型推理)
8+
- [2.1 模型获取](#21-模型获取)
9+
- [2.2 CPU 端推理样例](#22-CPU端推理样例)
10+
- [2.3 GPU 端推理样例](#23-GPU端推理样例)
11+
## 1. 环境准备
12+
ERNIE-M 的部署分为 CPU 和 GPU 两种情况,请根据你的部署环境安装对应的依赖。
13+
### 1.1 CPU端
14+
CPU 端的部署请使用如下命令安装所需依赖
15+
```
16+
pip install -r requirements_cpu.txt
17+
```
18+
### 1.2 GPU端
19+
为了在 GPU 上获得最佳的推理性能和稳定性,请先确保机器已正确安装 NVIDIA 相关驱动和基础软件,确保 CUDA >= 11.2,CuDNN >= 8.2,并使用以下命令安装所需依赖
20+
```
21+
pip install -r requirements_gpu.txt
22+
```
23+
24+
25+
## 2. 模型推理
26+
### 2.1 模型获取
27+
用户可使用自己训练的模型进行推理,具体训练调优方法可参考[模型训练调优](./../../README.md#模型训练)
28+
29+
### 2.2 CPU端推理样例
30+
在 CPU 端,请使用如下命令进行部署
31+
```sh
32+
python inference.py --device cpu --task_name seq_cls --model_path ../../finetuned_models/export/model
33+
```
34+
输出打印如下:
35+
```
36+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '我从来没有被告知任何与任何人会面。']
37+
seq cls result:
38+
label: contradiction confidence: 0.9331414103507996
39+
-----------------------------
40+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '我被告知将有一个人被叫进来与我见面。']
41+
seq cls result:
42+
label: entailment confidence: 0.9928839206695557
43+
-----------------------------
44+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '那个人来得有点晚。']
45+
seq cls result:
46+
label: neutral confidence: 0.9880155920982361
47+
-----------------------------
48+
```
49+
infer_cpu.py 脚本中的参数说明:
50+
| 参数 |参数说明 |
51+
|----------|--------------|
52+
|--task_name | 配置任务名称,默认 seq_cls|
53+
|--model_name_or_path | 模型的路径或者名字,默认为 ernie-m|
54+
|--model_path | 用于推理的 Paddle 模型的路径|
55+
|--max_seq_length |最大序列长度,默认为 256|
56+
|--precision_mode | 推理精度,可选 fp32,fp16 或者 int8,当输入非量化模型并设置 int8 时使用动态量化进行加速,默认 fp32 |
57+
|--num_threads | 配置 cpu 的线程数,默认为 cpu 的最大线程数 |
58+
59+
### 2.3 GPU端推理样例
60+
在 GPU 端,请使用如下命令进行部署
61+
```sh
62+
python inference.py --device gpu --task_name seq_cls --model_path ../../finetuned_models/export/model
63+
```
64+
输出打印如下:
65+
```
66+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '我从来没有被告知任何与任何人会面。']
67+
seq cls result:
68+
label: contradiction confidence: 0.932432234287262
69+
-----------------------------
70+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '我被告知将有一个人被叫进来与我见面。']
71+
seq cls result:
72+
label: entailment confidence: 0.9928724765777588
73+
-----------------------------
74+
input data: ['他们告诉我,呃,我最后会被叫到一个人那里去见面。', '那个人来得有点晚。']
75+
seq cls result:
76+
label: neutral confidence: 0.9880901575088501
77+
-----------------------------
78+
```
79+
如果需要 FP16 进行加速,可以设置 precision_mode 为 fp16,具体命令为
80+
```sh
81+
python inference.py --device gpu --task_name seq_cls --model_path ../../finetuned_models/export/model --precision_mode fp16
82+
```
83+
infer_gpu.py 脚本中的参数说明:
84+
| 参数 |参数说明 |
85+
|----------|--------------|
86+
|--task_name | 配置任务名称,可选 seq_cls|
87+
|--model_name_or_path | 模型的路径或者名字,默认为ernie-m-base|
88+
|--model_path | 用于推理的 Paddle 模型的路径|
89+
|--batch_size |最大可测的 batch size,默认为 32|
90+
|--max_seq_length |最大序列长度,默认为 256|
91+
|--precision_mode | 推理精度,可选 fp32,fp16 或者 int8,默认 fp32 |

0 commit comments

Comments
 (0)