model_zoo/bert/README.md (12 additions, 53 deletions)

@@ -268,66 +268,25 @@ python -u ./export_model.py \
- `model_path`: the path where the trained model was saved; same as the `output_dir` used during training.
- `output_path`: the filename prefix for the exported inference model. The suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`) are appended on save; in addition, the tokenizer files are saved under the directory containing `output_path`.

After the model export is complete, you can move on to deployment. The `deploy/python/seq_cls_infer.py` file provides a Python deployment and inference example. Run the deployment example with the following command:

```shell
python deploy/python/seq_cls_infer.py --model_dir infer_model/ --device gpu --backend paddle
```

After running, the prediction results are printed as follows:

```bash
[2023-03-02 08:30:03,877] [ INFO] - We are using <class 'paddlenlp.transformers.bert.fast_tokenizer.BertFastTokenizer'> to load '../../infer_model/'.
[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
```

For more detailed usage, see [Python Deployment](deploy/python/README.md).

## Extensions

model_zoo/bert/deploy/python/README.md (new file, 139 additions)

@@ -0,0 +1,139 @@
# FastDeploy BERT Model Python Deployment Example

Before deploying, install the FastDeploy Python SDK by following the [FastDeploy SDK installation documentation](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md).

This directory provides `seq_cls_infer.py`, a Python deployment example for quickly running the GLUE text classification task on CPU/GPU.

## Installing Dependencies

Run the following command to install the dependencies of the deployment example.

```bash
# Install fast_tokenizer and the GPU build of fastdeploy
pip install fast-tokenizer-python fastdeploy-gpu-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
```

## Quick Start

The following example shows how to use the FastDeploy library to run Python inference with a BERT model on the GLUE SST-2 sentiment classification dataset. The `--device` and `--backend` command-line flags select the hardware and the inference engine backend, and `--model_dir` specifies the model to run; see the [Parameter Description](#parameter-description) below for the full set of options. The model in this example is the deployment model exported by following the [BERT training documentation](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your setup).


```bash
# CPU inference
python seq_cls_infer.py --model_dir ../../infer_model/ --device cpu --backend paddle
# GPU inference
python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle
```

The results returned after running are as follows:

```bash
[2023-03-02 08:30:03,877] [ INFO] - We are using <class 'paddlenlp.transformers.bert.fast_tokenizer.BertFastTokenizer'> to load '../../infer_model/'.
[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
```

## Parameter Description

| Parameter | Description |
|----------|--------------|
|--model_dir | Directory of the deployment model |
|--batch_size | Batch size of the input; defaults to 1 |
|--max_length | Maximum sequence length; defaults to 128 |
|--device | Device to run on; choices: ['cpu', 'gpu']; defaults to 'cpu' |
|--device_id | ID of the device to run on; defaults to 0 |
|--cpu_threads | Number of CPU threads to use when running inference on CPU; defaults to 1 |
|--backend | Inference backend; choices: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt']; defaults to 'paddle' |
|--use_fp16 | Whether to run inference in FP16 mode; available with the tensorrt and paddle_tensorrt backends; defaults to False |
|--use_fast | Whether to use FastTokenizer to speed up tokenization; defaults to True |
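
For instance, the following command combines several of the flags above to run GPU inference through the Paddle TensorRT backend with FP16 enabled. This is an illustrative sketch: the model directory `../../infer_model/` is assumed to be the exported deployment model.

```bash
# Illustrative: GPU + paddle_tensorrt backend with FP16, batching 4 inputs at a time.
# Assumes ../../infer_model/ holds the exported model.
python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle_tensorrt --use_fp16 True --batch_size 4
```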

## FastDeploy Advanced Usage

On the Python side, FastDeploy provides the `fastdeploy.RuntimeOption.use_xxx()` and `fastdeploy.RuntimeOption.use_xxx_backend()` interfaces, which let developers choose the hardware and inference engine to deploy with. Deploying the BERT model on a given piece of hardware requires choosing an inference engine that the hardware supports; the table below shows which inference engines are available for deploying BERT on each kind of hardware.

Legend: (1) ✅: supported; (2) ❔: in progress; (3) N/A: not yet supported.

<table>
<tr>
<td align=center> Hardware</td>
<td align=center> Hardware API</td>
<td align=center> Available inference engines </td>
<td align=center> Inference engine API </td>
<td align=center> Supports Paddle new-format quantized models </td>
<td align=center> Supports FP16 mode </td>
</tr>
<tr>
<td rowspan=3 align=center> CPU </td>
<td rowspan=3 align=center> use_cpu() </td>
<td align=center> Paddle Inference </td>
<td align=center> use_paddle_infer_backend() </td>
<td align=center> ✅ </td>
<td align=center> N/A </td>
</tr>
<tr>
<td align=center> ONNX Runtime </td>
<td align=center> use_ort_backend() </td>
<td align=center> ✅ </td>
<td align=center> N/A </td>
</tr>
<tr>
<td align=center> OpenVINO </td>
<td align=center> use_openvino_backend() </td>
<td align=center> ❔ </td>
<td align=center> N/A </td>
</tr>
<tr>
<td rowspan=4 align=center> GPU </td>
<td rowspan=4 align=center> use_gpu() </td>
<td align=center> Paddle Inference </td>
<td align=center> use_paddle_infer_backend() </td>
<td align=center> ✅ </td>
<td align=center> N/A </td>
</tr>
<tr>
<td align=center> ONNX Runtime </td>
<td align=center> use_ort_backend() </td>
<td align=center> ✅ </td>
<td align=center> ❔ </td>
</tr>
<tr>
<td align=center> Paddle TensorRT </td>
<td align=center> use_paddle_infer_backend() + paddle_infer_option.enable_trt = True </td>
<td align=center> ✅ </td>
<td align=center> ✅ </td>
</tr>
<tr>
<td align=center> TensorRT </td>
<td align=center> use_trt_backend() </td>
<td align=center> ✅ </td>
<td align=center> ✅ </td>
</tr>
<tr>
<td align=center> Kunlunxin XPU </td>
<td align=center> use_kunlunxin() </td>
<td align=center> Paddle Lite </td>
<td align=center> use_paddle_lite_backend() </td>
<td align=center> N/A </td>
<td align=center> ✅ </td>
</tr>
<tr>
<td align=center> Huawei Ascend </td>
<td align=center> use_ascend() </td>
<td align=center> Paddle Lite </td>
<td align=center> use_paddle_lite_backend() </td>
<td align=center> ❔ </td>
<td align=center> ✅ </td>
</tr>
<tr>
<td align=center> Graphcore IPU </td>
<td align=center> use_ipu() </td>
<td align=center> Paddle Inference </td>
<td align=center> use_paddle_infer_backend() </td>
<td align=center> ❔ </td>
<td align=center> N/A </td>
</tr>
</table>
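
As a minimal sketch of the interfaces described above (assuming the exported model lives in `../../infer_model/` with the default `model` file prefix), selecting the hardware and backend looks like this; see `seq_cls_infer.py` below for the full preprocessing and inference flow:

```python
import os

import fastdeploy as fd

# Point the runtime at the exported Paddle model files.
model_dir = "../../infer_model/"  # assumption: directory of the exported deployment model
option = fd.RuntimeOption()
option.set_model_path(os.path.join(model_dir, "model.pdmodel"), os.path.join(model_dir, "model.pdiparams"))

# Pick hardware and engine per the table above, e.g. GPU 0 with Paddle Inference;
# swap in use_cpu() / use_ort_backend() / use_trt_backend() etc. as needed.
option.use_gpu(0)
option.use_paddle_infer_backend()

runtime = fd.Runtime(option)  # feed it input dicts via runtime.infer(...)
```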
model_zoo/bert/deploy/python/seq_cls_infer.py (new file, 159 additions)

@@ -0,0 +1,159 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import distutils.util
import os

import fastdeploy as fd
import numpy as np

from paddlenlp.transformers import AutoTokenizer


def parse_arguments():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", required=True, help="The directory of model.")
parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.")
parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.")
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["gpu", "cpu"],
help="Type of inference device, support 'cpu' or 'gpu'.",
)
parser.add_argument(
"--backend",
type=str,
default="paddle",
choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"],
help="The inference runtime backend.",
)
parser.add_argument("--cpu_threads", type=int, default=1, help="Number of threads to predict when using cpu.")
parser.add_argument("--device_id", type=int, default=0, help="Select which gpu device to train model.")
parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.")
parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.")
parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.")
parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode")
parser.add_argument(
"--use_fast",
type=distutils.util.strtobool,
default=True,
help="Whether to use fast_tokenizer to accelarate the tokenization.",
)
return parser.parse_args()


def batchfy_text(texts, batch_size):
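    # Split the input texts into consecutive batches of at most batch_size items.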
batch_texts = []
batch_start = 0
while batch_start < len(texts):
batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]]
batch_start += batch_size
return batch_texts


class Predictor(object):
def __init__(self, args):
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=args.use_fast)
self.runtime = self.create_fd_runtime(args)
self.batch_size = args.batch_size
self.max_length = args.max_length

def create_fd_runtime(self, args):
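        # Build a FastDeploy Runtime from the parsed CLI flags: pick the device first, then the backend.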
option = fd.RuntimeOption()
model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel")
params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams")
option.set_model_path(model_path, params_path)
if args.device == "cpu":
option.use_cpu()
option.set_cpu_thread_num(args.cpu_threads)
else:
option.use_gpu(args.device_id)
if args.backend == "paddle":
option.use_paddle_infer_backend()
elif args.backend == "onnx_runtime":
option.use_ort_backend()
elif args.backend == "openvino":
option.use_openvino_backend()
else:
option.use_trt_backend()
if args.backend == "paddle_tensorrt":
option.use_paddle_infer_backend()
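                # Collect tensor shapes at runtime so TensorRT engines can be built for dynamic input sizes.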
option.paddle_infer_option.collect_trt_shape = True
option.paddle_infer_option.enable_trt = True
trt_file = os.path.join(args.model_dir, "model.trt")
option.trt_option.set_shape(
"input_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length]
)
option.trt_option.set_shape(
"token_type_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length]
)
if args.use_fp16:
option.trt_option.enable_fp16 = True
trt_file = trt_file + ".fp16"
option.trt_option.serialize_file = trt_file
return fd.Runtime(option)

def preprocess(self, text):
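        # Tokenize the batch and map the arrays onto the runtime's expected input names.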
data = self.tokenizer(text, max_length=self.max_length, padding=True, truncation=True)
input_ids_name = self.runtime.get_input_info(0).name
token_type_ids_name = self.runtime.get_input_info(1).name
input_map = {
input_ids_name: np.array(data["input_ids"], dtype="int64"),
token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"),
}
return input_map

def infer(self, input_map):
results = self.runtime.infer(input_map)
return results

def postprocess(self, infer_data):
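        # Convert logits to probabilities with a numerically stable softmax (subtract the row max first).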
logits = np.array(infer_data[0])
max_value = np.max(logits, axis=1, keepdims=True)
exp_data = np.exp(logits - max_value)
probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
out_dict = {"label": probs.argmax(axis=-1), "confidence": probs}
return out_dict

def predict(self, texts):
input_map = self.preprocess(texts)
infer_result = self.infer(input_map)
output = self.postprocess(infer_result)
return output


if __name__ == "__main__":
args = parse_arguments()
predictor = Predictor(args)
texts_ds = [
"against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting",
"the situation in a well-balanced fashion",
"at achieving the modest , crowd-pleasing goals it sets for itself",
"so pat it makes your teeth hurt",
"this new jangle of noise , mayhem and stupidity must be a serious contender for the title .",
]
label_map = {0: "negative", 1: "positive"}
batch_texts = batchfy_text(texts_ds, args.batch_size)
for bs, texts in enumerate(batch_texts):
outputs = predictor.predict(texts)
for i, sentence1 in enumerate(texts):
print(
f"Batch id: {bs}, example id: {i}, sentence1: {sentence1}, "
f"label: {label_map[outputs['label'][i]]}, negative prob: {outputs['confidence'][i][0]:.4f}, "
f"positive prob: {outputs['confidence'][i][1]:.4f}."
)