
Commit eef6557

[Inference] Split core converter file into several files (#67877)
* split core
* fix bug
* delete comment
* delete code
1 parent 8efbe54 commit eef6557

File tree

12 files changed: +847 −577 lines changed


paddle/fluid/inference/tensorrt/pir/generic_plugin.cu

Lines changed: 4 additions & 0 deletions
@@ -573,6 +573,10 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
           (data_type == nvinfer1::DataType::kHALF));
 
   phi_kernel_contexts_[data_type]->ClearInputOutput();
+
+  auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
+  phi_kernel_contexts_[data_type]->SetDeviceContext(dev_ctx);
+
   auto& vec_kernel_fn_tensor_params = op_yaml_info_->TensorParams(true);
   int kernel_input_count = vec_kernel_fn_tensor_params.size();
   for (int i = 0; i < getNbInputs(); i++) {

python/paddle/tensorrt/converter.py

Lines changed: 17 additions & 7 deletions
@@ -23,7 +23,15 @@
 from paddle.base.core import get_value_shape_range_info
 from paddle.base.log_helper import get_logger
 
-from .impls.core import *  # noqa: F403
+from .impls.activation import *  # noqa: F403
+from .impls.conv import *  # noqa: F403
+from .impls.creation import *  # noqa: F403
+from .impls.linalg import *  # noqa: F403
+from .impls.manipulation import *  # noqa: F403
+from .impls.math import *  # noqa: F403
+from .impls.norm import *  # noqa: F403
+from .impls.pooling import *  # noqa: F403
+from .impls.search import *  # noqa: F403
 from .register import converter_registry
 from .util import map_dtype
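The wildcard imports do real work here: importing each impls module executes its @converter_registry.register decorators, so splitting core.py changes nothing for converter lookup as long as every new module is imported. register.py is not part of this diff; below is a hypothetical minimal sketch of the registry pattern it presumably implements, not the actual Paddle code.

class ConverterRegistry:
    def __init__(self):
        # (op_name, trt_version) -> converter function
        self._converters = {}

    def register(self, op_name, trt_version=None):
        def decorator(func):
            # Runs at import time, which is why converter.py only needs
            # `from .impls.xxx import *` to make converters available.
            self._converters[(op_name, trt_version)] = func
            return func

        return decorator

    def get(self, op_name, trt_version=None):
        return self._converters.get((op_name, trt_version))


converter_registry = ConverterRegistry()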

@@ -200,12 +208,11 @@ def convert_subgraph_to_trt(self, program, group_op):
                     f'{source_id} not found in value_to_trt_tensor'
                 )
 
-            layer = self.convert(network, op, operands)
+            trt_outs = self.convert(network, op, operands)
 
             for idx, result in enumerate(op.results()):
-                # TODO In some cases, the output index (idx) of a Paddle OP may not necessarily be the same as the output index of TensorRT
-                if idx < layer.num_outputs:
-                    value_to_trt_tensor[result.id] = layer.get_output(idx)
+                if idx < len(trt_outs):
+                    value_to_trt_tensor[result.id] = trt_outs[idx]
                 else:
                     value_to_trt_tensor[result.id] = None
             out_shapes = []
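For illustration only (not part of the diff): with converters now returning tuples, any declared result the converter did not produce maps to None rather than indexing past a layer's outputs. A toy version of the new loop, with strings standing in for ITensor objects:

trt_outs = ("t0",)               # converter produced one tensor
results = ["r0", "r1"]           # but the Paddle op declares two results
value_to_trt_tensor = {}
for idx, result in enumerate(results):
    value_to_trt_tensor[result] = trt_outs[idx] if idx < len(trt_outs) else None
print(value_to_trt_tensor)       # {'r0': 't0', 'r1': None}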
@@ -307,8 +314,11 @@ def convert(self, network, paddle_op, inputs):
             raise NotImplementedError(
                 f"Converter for {op_name} not implemented."
             )
-        out = converter_func(network, paddle_op, inputs)
-        return out
+        outs = converter_func(network, paddle_op, inputs)
+        if isinstance(outs, tuple):
+            return outs
+        else:
+            return (outs,)
 
     def convert_program_to_trt(self):
         for op in self.program.global_block().ops:
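A standalone illustration of what the new tail of convert() guarantees: single-tensor returns are wrapped so callers can always index the result as a tuple. The fake converters below are stand-ins, not real Paddle converters; note the wrap must be (outs,) rather than tuple(outs), since tuple() would try to iterate a single ITensor.

def fake_converter_single(network, op, inputs):
    return "t0"                   # most converters return one ITensor

def fake_converter_multi(network, op, inputs):
    return ("t0", "t1")           # some ops map to several TRT tensors

def normalize(outs):
    return outs if isinstance(outs, tuple) else (outs,)

assert normalize(fake_converter_single(None, None, None)) == ("t0",)
assert normalize(fake_converter_multi(None, None, None)) == ("t0", "t1")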
python/paddle/tensorrt/impls/activation.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import sys
+
+import numpy as np
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
+if parent_dir not in sys.path:
+    sys.path.append(parent_dir)
+
+import tensorrt as trt
+
+from paddle.base.log_helper import get_logger
+from paddle.tensorrt.register import converter_registry
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+)
+from paddle.tensorrt.converter_utils import (
+    get_trt_plugin,
+)
+
+
+@converter_registry.register("pd_op.relu", trt_version="8.x")
+def relu_converter(network, paddle_op, inputs):
+    relu_layer = network.add_activation(inputs[0], trt.ActivationType.RELU)
+    return relu_layer.get_output(0)
+
+
+@converter_registry.register("pd_op.softmax", trt_version="8.x")
+def softmax_converter(network, paddle_op, inputs):
+    axis = paddle_op.attrs().get("axis", 0)
+    if axis < 0:
+        axis = len(inputs[0].shape) + axis
+
+    softmax_layer = network.add_softmax(inputs[0])
+    softmax_layer.axes = 1 << axis
+    return softmax_layer.get_output(0)
+
+
+@converter_registry.register("pd_op.gelu", trt_version="8.x")
+def gelu_converter(network, paddle_op, inputs):
+    input_val = inputs[0]
+    approximate = paddle_op.attrs()["approximate"]
+    if approximate:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support fast gelu compute"
+        )
+
+    plugin_name = "CustomGeluPluginDynamic"
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+
+    field_collection = trt.PluginFieldCollection([type_id])
+    plugin_version = "1"
+
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
+
+    layer = network.add_plugin_v2([input_val], plugin)
+    return layer.get_output(0)
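One detail worth noting in softmax_converter: TensorRT's ISoftMaxLayer.axes is a bitmask, not an index, so bit i selects dimension i of the input. A quick standalone illustration:

# For an NCHW input, axis=1 (channels) gives the mask 0b0010.
for axis in range(4):
    print(f"axis={axis} -> axes bitmask {1 << axis:#06b}")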
python/paddle/tensorrt/impls/conv.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import sys
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
+if parent_dir not in sys.path:
+    sys.path.append(parent_dir)
+
+import tensorrt as trt
+
+from paddle.base.log_helper import get_logger
+from paddle.tensorrt.register import converter_registry
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+)
+
+
+@converter_registry.register("pd_op.conv2d", trt_version="8.x")
+def conv2d_converter(network, paddle_op, inputs):
+    input_tensor, weight = inputs
+    weight_shape = paddle_op.operands()[1].source().shape
+
+    padding = paddle_op.attrs().get("paddings", [0, 0])
+    stride = paddle_op.attrs().get("strides", [1, 1])
+    dilation = paddle_op.attrs().get("dilations", [1, 1])
+    groups = paddle_op.attrs().get("groups", 1)
+
+    # weight_tensor = network.add_constant(weight_shape, weight).get_output(0)
+    kernel_shape = trt.Dims((weight_shape[2], weight_shape[3]))
+
+    conv_layer = network.add_convolution_nd(
+        input_tensor, weight_shape[0], kernel_shape, weight
+    )
+    conv_layer.stride_nd = stride
+    conv_layer.padding_nd = padding
+    conv_layer.dilation_nd = dilation
+    conv_layer.num_groups = groups
+
+    return conv_layer.get_output(0)
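A hypothetical smoke test for conv2d_converter (assumes a local TensorRT 8.x Python install; _StubOp and _StubOperand are invented stand-ins exposing only the attrs()/operands() the converter reads, not real pir objects):

import numpy as np
import tensorrt as trt

class _StubOperand:
    def __init__(self, shape):
        self.shape = shape  # weight shape the converter reads

    def source(self):
        return self

class _StubOp:
    def __init__(self, weight_shape):
        self._weight = _StubOperand(weight_shape)

    def attrs(self):
        return {"paddings": [1, 1], "strides": [1, 1],
                "dilations": [1, 1], "groups": 1}

    def operands(self):
        return [None, self._weight]

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 32, 32))
w = trt.Weights(np.zeros((8, 3, 3, 3), dtype=np.float32))
out = conv2d_converter(network, _StubOp([8, 3, 3, 3]), [x, w])
print(out.shape)  # expected (1, 8, 32, 32): kernel 3, padding 1, stride 1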
