58 changes: 58 additions & 0 deletions faster_generation/samples/t5_sample.py
@@ -0,0 +1,58 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max_length", default=256, type=int, help="Maximum output sequence length.")
parser.add_argument("--beam_size", default=4, type=int, help="The beam size to set.")
parser.add_argument("--use_faster", action="store_true", help="Whether to use faster to predict.")
parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use fp16 to predict.")
args = parser.parse_args()
return args


def predict(args):
model_name = "t5-base"

model = T5ForConditionalGeneration.from_pretrained(model_name)
model.eval()
tokenizer = T5Tokenizer.from_pretrained(model_name)

en_text = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '
input_ids = tokenizer.encode("translate English to French: " + en_text, return_tensors="pd")["input_ids"]

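    # generate() returns a tuple of (generated token ids, scores). With
    # use_faster=True, decoding is routed to the fused FasterTransformer
    # op when it is available, falling back to the dygraph implementation.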
output, _ = model.generate(
input_ids=input_ids,
num_beams=args.beam_size,
max_length=args.max_length,
decode_strategy="beam_search",
use_faster=args.use_faster,
use_fp16_decoding=args.use_fp16_decoding,
)

translation = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

print("The original sentence: ", en_text)
print("The translation result: ", translation)


if __name__ == "__main__":
args = parse_args()

predict(args)
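
A note on --use_faster: the first faster-generation call JIT-compiles the FasterTransformer custom op if no prebuilt library is found, which can take a while on a cold machine. A minimal sketch to trigger that build ahead of time (the same loader call that t5_inference.py uses further below):

    from paddlenlp.ops.ext_utils import load

    # Compile (or reuse a cached build of) the FasterTransformer custom op;
    # verbose=True streams the build log.
    load("FasterTransformer", verbose=True)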
23 changes: 21 additions & 2 deletions paddlenlp/ops/CMakeLists.txt
@@ -28,6 +28,7 @@ option(WITH_TRANSFORMER "Compiled with Transformer."
option(WITH_GPT "Compiled with GPT." ON)
option(WITH_OPT "Compiled with OPT." ON)
option(WITH_UNIFIED "Compiled with Unified Transformer." ON)
option(WITH_T5 "Compiled with T5." ON)
option(WITH_SP "Compiled with sentencepiece. Only works when WITH_GPT and ON_INFER is ON." OFF)
option(WITH_DECODER "Compile with Transformer Decoder" ON)
option(WITH_ENCODER "Compile with Transformer Encoder" ON)
@@ -95,8 +96,12 @@ if(WITH_PEGASUS)
list(APPEND decoding_op_files fusion_pegasus_decoding_op.cc fusion_pegasus_decoding_op.cu)
endif()

if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART AND NOT WITH_GPTJ AND NOT WITH_PEGASUS)
message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON or/and -DWITH_GPTJ=ON or/and -DWITH_PEGASUS=ON must be set to use FasterTransformer. ")
if(WITH_T5)
list(APPEND decoding_op_files fusion_t5_decoding_op.cc fusion_t5_decoding_op.cu)
endif()

if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART AND NOT WITH_GPTJ AND NOT WITH_PEGASUS AND NOT WITH_T5)
  message(FATAL_ERROR "-DWITH_TRANSFORMER=ON and/or -DWITH_GPT=ON and/or -DWITH_DECODER=ON and/or -DWITH_ENCODER=ON and/or -DWITH_BART=ON and/or -DWITH_MBART=ON and/or -DWITH_GPTJ=ON and/or -DWITH_PEGASUS=ON and/or -DWITH_T5=ON must be set to use FasterTransformer. ")
endif()

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
@@ -277,6 +282,9 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/utils/common.h common_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/utils/common.h common_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/utils/common_structure.h common_structure_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/utils/common_structure.h common_structure_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/CMakeLists.txt cmakelists_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/CMakeLists.txt cmakelists_dst)

@@ -322,6 +330,9 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu trans_kernels_cu_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/transformer_kernels.cu trans_kernels_cu_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cuh trans_kernels_cuh_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/transformer_kernels.cuh trans_kernels_cuh_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu masked_multihead_attention_cu_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention.cu masked_multihead_attention_cu_dst)

@@ -346,6 +357,10 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_dst)

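# T5 patches: the new beam-search and sampling headers below are copied
# into the root of the FasterTransformer source tree (ft_dst).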
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/t5_beamsearch.h t5_bs_h_src)
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/t5_sampling.h t5_spl_h_src)
set(ft_dst ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/)

# Encoder patches.
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
@@ -370,6 +385,7 @@ set(FT_PATCH_COMMAND
printf \\n\\n > blank_lines
&& cp ${allocator_src} ${allocator_dst}
&& cp ${common_src} ${common_dst}
&& cp ${common_structure_src} ${common_structure_dst}
&& cp ${cmakelists_src} ${cmakelists_dst}
&& cp ${topk_kernels_src} ${topk_kernels_dst}
&& cp ${decoding_beamsearch_h_src} ${decoding_beamsearch_h_dst}
@@ -387,6 +403,8 @@ set(FT_PATCH_COMMAND
&& cp ${opt_h_src} ${opt_h_dst}
&& cp ${gptj_h_src} ${gptj_h_dst}
&& cp ${masked_multihead_attention_h_src} ${masked_multihead_attention_h_dst}
&& cp ${t5_bs_h_src} ${ft_dst}
&& cp ${t5_spl_h_src} ${ft_dst}
&& cat blank_lines ${masked_multihead_attention_utils_h_src} >> ${masked_multihead_attention_utils_h_dst}
&& cat blank_lines ${attention_kernels_cu_src} >> ${attention_kernels_cu_dst}
&& cat blank_lines ${attention_kernels_cuh_src} >> ${attention_kernels_cuh_dst}
@@ -398,6 +416,7 @@ set(FT_PATCH_COMMAND
&& cat blank_lines ${trans_decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
&& cat blank_lines ${open_decoder_cu_src} >> ${open_decoder_cu_dst}
&& cat blank_lines ${open_decoder_cuh_src} >> ${open_decoder_cuh_dst}
&& cat blank_lines ${trans_kernels_cuh_src} >> ${trans_kernels_cuh_dst}
&& sed -i "s/^#define NEW_TRANSPOSE_BATCH_MAJOR 1/#define NEW_TRANSPOSE_BATCH_MAJOR 0/g" ${open_decoder_cu_dst}
&& sed -i "2091,2119d" ${open_attention_cu_dst}
&& rm blank_lines
113 changes: 113 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/t5_export_model_sample.py
@@ -0,0 +1,113 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pprint import pprint

import paddle

from paddlenlp.ops import FasterT5
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer
from paddlenlp.utils.log import logger


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path", default="t5-base", type=str, help="The model name to specify the bart to use. "
)
parser.add_argument(
"--inference_model_dir", default="./infer_model/", type=str, help="Path to save inference model of bart. "
)
parser.add_argument("--topk", default=4, type=int, help="The number of candidate to procedure top_k sampling. ")
parser.add_argument(
"--topp", default=1.0, type=float, help="The probability threshold to procedure top_p sampling. "
)
parser.add_argument("--max_out_len", default=256, type=int, help="Maximum output length. ")
parser.add_argument("--temperature", default=1.0, type=float, help="The temperature to set. ")
parser.add_argument("--num_return_sequences", default=1, type=int, help="The number of returned sequences. ")
parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use fp16 decoding to predict. ")
parser.add_argument(
"--decoding_strategy",
default="beam_search",
choices=["sampling", "beam_search"],
type=str,
help="The main strategy to decode. ",
)
parser.add_argument("--num_beams", default=4, type=int, help="The number of candidate to procedure beam search. ")
parser.add_argument(
"--diversity_rate", default=0.0, type=float, help="The diversity rate to procedure beam search. "
)
parser.add_argument("--repetition_penalty", default=1.0, type=float, help="The repetition_penalty to set. ")
parser.add_argument("--length_penalty", default=0.0, type=float, help="The length penalty to decode. ")
parser.add_argument("--early_stopping", action="store_true", help="Whether to do early stopping. ")

args = parser.parse_args()
return args


def do_predict(args):
place = "gpu"
place = paddle.set_device(place)

model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)

    # Set eval mode to enable the faster encoder.
model.eval()

faster_t5 = FasterT5(model=model, use_fp16_decoding=args.use_fp16_decoding)
    # Set evaluation mode.
faster_t5.eval()

# Convert dygraph model to static graph model
faster_t5 = paddle.jit.to_static(
faster_t5,
input_spec=[
# input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
# encoder_output
None,
# seq_len
None,
args.max_out_len, # max_length
0, # min_length
args.topk, # top_k
args.topp, # top_p
args.num_beams, # num_beams
args.decoding_strategy, # decode_strategy
None, # decoder_start_token_id
tokenizer.bos_token_id, # bos_token_id
tokenizer.eos_token_id, # eos_token_id
tokenizer.pad_token_id, # pad_token_id
args.diversity_rate, # diversity_rate
args.temperature, # temperature
args.num_return_sequences, # num_return_sequences
args.length_penalty, # length_penalty
args.early_stopping, # early_stopping
tokenizer.eos_token_id, # forced_eos_token_id
],
)

# Save converted static graph model
paddle.jit.save(faster_t5, os.path.join(args.inference_model_dir, "t5"))
logger.info("T5 has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
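
paddle.jit.save writes the exported program and weights using the "t5" prefix, so the default run produces ./infer_model/t5.pdmodel and ./infer_model/t5.pdiparams, which is exactly the pair t5_inference.py below loads. A quick sanity check (a sketch; paths assume the --inference_model_dir default):

    import os

    model_dir = "./infer_model/"
    for name in ("t5.pdmodel", "t5.pdiparams"):
        # Both files must exist before running the inference sample.
        assert os.path.isfile(os.path.join(model_dir, name)), f"missing {name}"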
94 changes: 94 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/t5_inference.py
@@ -0,0 +1,94 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pprint import pprint

import paddle.inference as paddle_infer

from paddlenlp.ops.ext_utils import load
from paddlenlp.transformers import T5Tokenizer


def setup_args():
"""Setup arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--inference_model_dir", default="./infer_model/", type=str, help="Path to save inference model of T5. "
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size. ")

args = parser.parse_args()

return args


def postprocess_response(tokenizer, seq, bos_idx, eos_idx):
"""Post-process the decoded sequence."""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [idx for idx in seq[: eos_pos + 1] if idx != bos_idx and idx != eos_idx]
res = tokenizer.convert_ids_to_string(seq)
return res


def infer(args):
model_name = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_name)
inputs = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '

# Input ids
input_ids = tokenizer.encode("translate English to French: " + inputs, return_tensors="np")["input_ids"]

# Load FasterTransformer lib.
load("FasterTransformer", verbose=True)

config = paddle_infer.Config(
os.path.join(args.inference_model_dir, "t5.pdmodel"), os.path.join(args.inference_model_dir, "t5.pdiparams")
)

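    # Reserve an initial 100 MB GPU memory pool on device 0.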
config.enable_use_gpu(100, 0)
config.disable_glog_info()
predictor = paddle_infer.create_predictor(config)

input_names = predictor.get_input_names()

# Input ids
input_ids_handle = predictor.get_input_handle(input_names[0])
input_ids_handle.copy_from_cpu(input_ids.astype("int32"))

predictor.run()

output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()

    # Transpose from [seq_len, batch_size, num_beams * 2] to
    # [batch_size, num_beams * 2, seq_len].
    output_data = output_data.transpose([1, 2, 0])

# Only use the best sequence.
translation = tokenizer.decode(output_data[0][0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
print("Result:", translation)


if __name__ == "__main__":
args = setup_args()
pprint(args)

infer(args)
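
The transposed output holds num_beams * 2 candidate sequences per example, and the script prints only the first. The rest can be decoded the same way (a sketch, assuming candidates are ordered best-first, as the "best sequence" comment above implies):

    # Decode every returned candidate for the first example in the batch.
    for k in range(output_data.shape[1]):
        candidate = tokenizer.decode(
            output_data[0][k], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        print(f"Candidate {k}: {candidate}")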
2 changes: 2 additions & 0 deletions paddlenlp/ops/faster_transformer/src/CMakeLists.txt
@@ -8,6 +8,8 @@ add_definitions(-DPADDLE_CUDA)
# Default is 1 in standard c++ when using gcc8.2
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)

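# -w suppresses all compiler warnings from the vendored FasterTransformer sources.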
add_definitions(-w)

if(ON_INFER)
add_definitions(-DPADDLE_ON_INFERENCE)
