Merged
4 changes: 4 additions & 0 deletions paddle/fluid/inference/tensorrt/pir/generic_plugin.cu

@@ -573,6 +573,10 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                 (data_type == nvinfer1::DataType::kHALF));
 
   phi_kernel_contexts_[data_type]->ClearInputOutput();
+
+  auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(place));
+  phi_kernel_contexts_[data_type]->SetDeviceContext(dev_ctx);
+
   auto& vec_kernel_fn_tensor_params = op_yaml_info_->TensorParams(true);
   int kernel_input_count = vec_kernel_fn_tensor_params.size();
   for (int i = 0; i < getNbInputs(); i++) {
24 changes: 17 additions & 7 deletions python/paddle/tensorrt/converter.py

@@ -23,7 +23,15 @@
 from paddle.base.core import get_value_shape_range_info
 from paddle.base.log_helper import get_logger
 
-from .impls.core import *  # noqa: F403
+from .impls.activation import *  # noqa: F403
+from .impls.conv import *  # noqa: F403
+from .impls.creation import *  # noqa: F403
+from .impls.linalg import *  # noqa: F403
+from .impls.manipulation import *  # noqa: F403
+from .impls.math import *  # noqa: F403
+from .impls.norm import *  # noqa: F403
+from .impls.pooling import *  # noqa: F403
+from .impls.search import *  # noqa: F403
 from .register import converter_registry
 from .util import map_dtype
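Why star imports: importing each impls module is what executes its @converter_registry.register(...) decorators, so the registry is populated as an import side effect. A minimal standalone sketch of that mechanism (a hypothetical stand-in, not Paddle's actual register.py):

# Hypothetical miniature of the registry pattern used in this PR.
_REGISTRY = {}


def register(op_name, trt_version=None):
    def wrap(fn):
        # Importing a module that uses @register fills the table.
        _REGISTRY[(op_name, trt_version)] = fn
        return fn

    return wrap


@register("pd_op.relu", trt_version="8.x")
def relu_converter(network, paddle_op, inputs):
    pass


print(list(_REGISTRY))  # [('pd_op.relu', '8.x')]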
@@ -200,12 +208,11 @@ def convert_subgraph_to_trt(self, program, group_op):
                         f'{source_id} not found in value_to_trt_tensor'
                     )
 
-            layer = self.convert(network, op, operands)
+            trt_outs = self.convert(network, op, operands)
 
             for idx, result in enumerate(op.results()):
-                # TODO: in some cases, the output index (idx) of a Paddle op may not match the TensorRT output index
-                if idx < layer.num_outputs:
-                    value_to_trt_tensor[result.id] = layer.get_output(idx)
+                if idx < len(trt_outs):
+                    value_to_trt_tensor[result.id] = trt_outs[idx]
                 else:
                     value_to_trt_tensor[result.id] = None
             out_shapes = []
@@ -307,8 +314,11 @@ def convert(self, network, paddle_op, inputs):
             raise NotImplementedError(
                 f"Converter for {op_name} not implemented."
             )
-        out = converter_func(network, paddle_op, inputs)
-        return out
+        outs = converter_func(network, paddle_op, inputs)
+        if isinstance(outs, tuple):
+            return outs
+        else:
+            return (outs,)
 
     def convert_program_to_trt(self):
         for op in self.program.global_block().ops:
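For reference, a standalone sketch of the new return-value contract (plain strings stand in for trt.ITensor handles; not part of the PR): converters may return one tensor or a tuple, and convert() normalizes to a tuple so convert_subgraph_to_trt can take len() and index positionally.

def normalize_outputs(outs):
    # Mirrors Converter.convert: a single output is wrapped so callers
    # can always index by result position.
    return outs if isinstance(outs, tuple) else (outs,)


print(normalize_outputs("t0"))          # ('t0',)
print(normalize_outputs(("t0", "t1")))  # ('t0', 't1')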
75 changes: 75 additions & 0 deletions python/paddle/tensorrt/impls/activation.py

@@ -0,0 +1,75 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

import numpy as np

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import tensorrt as trt

from paddle.base.log_helper import get_logger
from paddle.tensorrt.converter_utils import get_trt_plugin
from paddle.tensorrt.register import converter_registry

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


@converter_registry.register("pd_op.relu", trt_version="8.x")
def relu_converter(network, paddle_op, inputs):
    relu_layer = network.add_activation(inputs[0], trt.ActivationType.RELU)
    return relu_layer.get_output(0)


@converter_registry.register("pd_op.softmax", trt_version="8.x")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个属于activation吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,softmax api在paddle里是放在activation里边的

def softmax_converter(network, paddle_op, inputs):
    axis = paddle_op.attrs().get("axis", 0)
    if axis < 0:
        axis = len(inputs[0].shape) + axis

    softmax_layer = network.add_softmax(inputs[0])
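    # ISoftMaxLayer.axes is a bitmask of dimensions; set only the bit for axis.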
    softmax_layer.axes = 1 << axis
    return softmax_layer.get_output(0)


@converter_registry.register("pd_op.gelu", trt_version="8.x")
def gelu_converter(network, paddle_op, inputs):
    input_val = inputs[0]
    approximate = paddle_op.attrs()["approximate"]
    if approximate:
        raise RuntimeError(
            "GeLU converter currently does not support the tanh-approximate (fast) GELU"
        )

    plugin_name = "CustomGeluPluginDynamic"
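    # For this plugin, "type_id" encodes the compute precision (0 = FP32).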
    type_id = trt.PluginField(
        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
    )

    field_collection = trt.PluginFieldCollection([type_id])
    plugin_version = "1"

    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

    layer = network.add_plugin_v2([input_val], plugin)
    return layer.get_output(0)
54 changes: 54 additions & 0 deletions python/paddle/tensorrt/impls/conv.py

@@ -0,0 +1,54 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import tensorrt as trt

from paddle.base.log_helper import get_logger
from paddle.tensorrt.register import converter_registry

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


@converter_registry.register("pd_op.conv2d", trt_version="8.x")
def conv2d_converter(network, paddle_op, inputs):
    input_tensor, weight = inputs
    weight_shape = paddle_op.operands()[1].source().shape

    padding = paddle_op.attrs().get("paddings", [0, 0])
    stride = paddle_op.attrs().get("strides", [1, 1])
    dilation = paddle_op.attrs().get("dilations", [1, 1])
    groups = paddle_op.attrs().get("groups", 1)
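    # Conv2d weights are laid out (out_channels, in_channels // groups, kH, kW),
    # so dims 2 and 3 give the kernel's spatial shape.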
    kernel_shape = trt.Dims((weight_shape[2], weight_shape[3]))

    conv_layer = network.add_convolution_nd(
        input_tensor, weight_shape[0], kernel_shape, weight
    )
    conv_layer.stride_nd = stride
    conv_layer.padding_nd = padding
    conv_layer.dilation_nd = dilation
    conv_layer.num_groups = groups

    return conv_layer.get_output(0)
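All converters in this PR follow the same registration pattern; as a hedged illustration (a hypothetical pd_op.sigmoid converter, not included in this PR), a new op would plug in like this:

import tensorrt as trt

from paddle.tensorrt.register import converter_registry


@converter_registry.register("pd_op.sigmoid", trt_version="8.x")
def sigmoid_converter(network, paddle_op, inputs):
    # Returning a single ITensor is fine; Converter.convert wraps it in a tuple.
    layer = network.add_activation(inputs[0], trt.ActivationType.SIGMOID)
    return layer.get_output(0)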