Commit 6043c7f

[Paddle TensorRT] fix pd_op.pool2d (PaddlePaddle#69864)
* fix pool2d
* fix
1 parent 87f6360 commit 6043c7f

File tree

4 files changed: +355 −45 lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 54 additions & 0 deletions
@@ -331,6 +331,60 @@ class Pool2dOpPattern
         }
       }
     }
+
+    auto ceil_mode = op->attribute<pir::BoolAttribute>("ceil_mode").data();
+    auto global_pooling =
+        op->attribute<pir::BoolAttribute>("global_pooling").data();
+    std::string padding_algorithm =
+        op->attribute<pir::StrAttribute>("padding_algorithm").AsString();
+    // TODO(Lizexu): The general plugin approach for entering TensorRT has not
+    // been supported yet.
+    auto adaptive = op->attribute<pir::BoolAttribute>("adaptive").data();
+    if (adaptive) {
+      VLOG(3)
+          << "The adaptive is true pd_op.pool2d is not supported by trt now";
+      return false;
+    }
+    // TODO(Lizexu): This piece of code exists in the old IR-TRT implementation
+    // but is not covered by unit tests, raising suspicions about its
+    // correctness. In the PIR-TRT implementation, following the same approach
+    // causes precision issues. For now, we will exclude it from entering
+    // TensorRT.
+    pir::Value input = op.operand_source(0);
+    auto kernel_size_attr =
+        full_int_array_op->attribute<pir::ArrayAttribute>("value");
+    std::vector<int64_t> kernel_size;
+    for (const auto &attr : kernel_size_attr.AsVector()) {
+      kernel_size.push_back(attr.dyn_cast<pir::Int64Attribute>().data());
+    }
+
+    auto input_type = input.type().dyn_cast<paddle::dialect::DenseTensorType>();
+    auto input_dims = input_type.dims();
+    int g_post_pad_h = 0;
+    int g_post_pad_w = 0;
+    int input_height = input_dims[input_dims.size() - 2];
+    int input_width = input_dims[input_dims.size() - 1];
+    std::vector<int32_t> strides;
+    auto strides_attr = op->attribute<pir::ArrayAttribute>("strides");
+    for (const auto &attr : strides_attr.AsVector()) {
+      strides.push_back(attr.dyn_cast<pir::Int32Attribute>().data());
+    }
+    if (input_height > 0 &&
+        input_height - kernel_size[0] + 2 * paddings[0] < 0) {
+      g_post_pad_h = strides[0] - 1;
+    }
+    if (input_width > 0 && input_width - kernel_size[1] + 2 * paddings[1] < 0) {
+      g_post_pad_w = strides[1] - 1;
+    }
+    if (!adaptive && !global_pooling && !ceil_mode) {
+      if (padding_algorithm != "SAME" &&
+          ((g_post_pad_h > 0 && input_height > 0) ||
+           (g_post_pad_w > 0 && input_width > 0))) {
+        VLOG(3) << "The pool2d op meets the condition that may cause precision "
+                   "issues in TRT. Skip TRT conversion.";
+        return false;
+      }
+    }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
     return true;
   }
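In plain terms, the new guard refuses TRT conversion whenever the pooling window, even after explicit padding, overhangs the input, because TensorRT would then need an extra post-padding of stride - 1, which is the case that showed precision issues. A minimal Python sketch of that check, with a hypothetical helper name (not part of the commit):

    # Hypothetical sketch mirroring g_post_pad_h / g_post_pad_w above.
    def needs_post_pad(input_hw, kernel_size, paddings, strides):
        post = [0, 0]
        for i in range(2):
            # Window overhangs the padded input -> stride - 1 extra post-pad.
            if input_hw[i] > 0 and input_hw[i] - kernel_size[i] + 2 * paddings[i] < 0:
                post[i] = strides[i] - 1
        return tuple(post)

    # Example: a 3x3 input with a 5x5 window and zero padding overhangs,
    # so the marker pass keeps this pool2d out of TensorRT.
    assert needs_post_pad((3, 3), (5, 5), (0, 0), (2, 2)) == (1, 1)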

python/paddle/tensorrt/impls/pooling.py

Lines changed: 115 additions & 44 deletions
@@ -21,68 +21,139 @@
 @converter_registry.register("pd_op.pool2d", trt_version="8.x")
 def pool2d_converter(network, paddle_op, inputs):
     input_tensor = inputs[0]
-    pooling_type = paddle_op.attrs().get("pooling_type", "max")
-    padding = paddle_op.attrs().get("paddings", [0, 0])
-    stride = paddle_op.attrs().get("strides", [1, 1])
-    ceil_mode = paddle_op.attrs().get("ceil_mode", False)
-    exclusive = paddle_op.attrs().get("exclusive")
-    adaptive = paddle_op.attrs().get("adaptive")
-    padding_algorithm = paddle_op.attrs().get("padding_algorithm")
 
-    input_shape = input_tensor.shape
+    input_shape = paddle_op.operands()[0].source().shape
+    input_dims = len(input_shape)
+
+    global_pooling = paddle_op.attrs().get("global_pooling", False)
+    pool_type = paddle_op.attrs().get("pooling_type")
+    strides = paddle_op.attrs().get("strides")
+    paddings = paddle_op.attrs().get("paddings")
+    exclusive = paddle_op.attrs().get("exclusive", True)
+    ceil_mode = paddle_op.attrs().get("ceil_mode", False)
+    adaptive = paddle_op.attrs().get("adaptive", False)
+    padding_algorithm = paddle_op.attrs().get("padding_algorithm", "EXPLICIT")
 
-    # TODO attention for these codes
     if not paddle_op.attrs().get("kernel_size") and len(inputs) == 2:
-        # the size of pool2d inputs is 2, means kernel size is the second input.
-        # kernel_size_tensor = inputs[1]
         full_int_op = paddle_op.operands()[1].source().get_defining_op()
         if full_int_op.name() == "pd_op.full_int_array":
             kernel_size = full_int_op.attrs().get("value")
         else:
             raise Exception(
-                "the defining op of kernel size must be pd_op.full_int_array"
+                "The defining op of kernel size must be pd_op.full_int_array"
             )
     else:
         kernel_size = paddle_op.attrs().get("kernel_size")
 
-    if len(stride) == 0 or stride[0] is None:
-        stride = kernel_size
+    nv_pool_type = trt.PoolingType.MAX
+    reduce_operation = trt.ReduceOperation.MAX
+    if pool_type == "max":
+        nv_pool_type = trt.PoolingType.MAX
+        reduce_operation = trt.ReduceOperation.MAX
+    elif pool_type == "avg":
+        nv_pool_type = trt.PoolingType.AVERAGE
+        reduce_operation = trt.ReduceOperation.AVG
 
-    if pooling_type == "max":
-        pooling_type = trt.PoolingType.MAX
-    elif pooling_type == "avg":
-        pooling_type = trt.PoolingType.AVERAGE
-    else:
-        raise ValueError(f"Unsupported pooling type: {pooling_type}")
+    if global_pooling or adaptive:
+        paddings = [0] * len(paddings)
 
     if padding_algorithm == "VALID":
-        padding = [0, 0]
-
-    if adaptive:
-        output_size = kernel_size
-        stride = tuple(input_shape[-2 + i] // output_size[i] for i in range(2))
-        kernel_size = tuple(
-            input_shape[-2 + i] - (output_size[i] - 1) * stride[i]
-            for i in range(2)
+        paddings = [0] * len(paddings)
+
+    nv_paddings = trt.DimsHW(paddings[0], paddings[1])
+    nv_ksize = trt.DimsHW(kernel_size[0], kernel_size[1])
+    nv_strides = trt.DimsHW(strides[0], strides[1])
+
+    layer = None
+    g_pre_pad = trt.DimsHW(0, 0)
+    g_post_pad = trt.DimsHW(0, 0)
+
+    if (
+        input_shape[input_dims - 2] > 0
+        and input_shape[input_dims - 2] - kernel_size[0] + 2 * paddings[0] < 0
+    ):
+        g_post_pad.h = strides[0] - 1
+    if (
+        input_shape[input_dims - 1] > 0
+        and input_shape[input_dims - 1] - kernel_size[1] + 2 * paddings[1] < 0
+    ):
+        g_post_pad.w = strides[1] - 1
+
+    real_paddings = paddings.copy()
+    for i in range(2):
+        copy_pad = paddings[i]
+        real_paddings.insert(2 * i + 1, copy_pad)
+
+    if padding_algorithm == "SAME":
+        for i in range(2):
+            copy_pad = paddings[2 * i]
+            paddings.insert(2 * i + 1, copy_pad)
+
+        for i in range(2):
+            out_size = (input_shape[2 + i] + strides[i] - 1) // strides[i]
+            pad_sum = max(
+                (out_size - 1) * strides[i]
+                + kernel_size[i]
+                - input_shape[2 + i],
+                0,
+            )
+            pad_0 = pad_sum // 2
+            pad_1 = pad_sum - pad_0
+            paddings[2 * i] = pad_0
+            paddings[2 * i + 1] = pad_1
+        real_paddings = paddings.copy()
+
+        paddings = [paddings[i] for i in range(len(paddings)) if i % 2 == 0]
+
+    if padding_algorithm == "VALID":
+        real_paddings = [0] * len(real_paddings)
+
+    if not adaptive and not global_pooling and not ceil_mode:
+        if padding_algorithm != "SAME" and (
+            (g_post_pad.h > 0 and input_shape[input_dims - 2] > 0)
+            or (g_post_pad.w > 0 and input_shape[input_dims - 1] > 0)
+        ):
+            pad_layer = network.add_padding_nd(
+                input=input_tensor,
+                pre_padding=tuple(g_pre_pad),
+                post_padding=tuple(g_post_pad),
+            )
+            input_tensor = pad_layer.get_output(0)
+        pooling_layer = network.add_pooling_nd(
+            input=input_tensor, type=nv_pool_type, window_size=nv_ksize
         )
+        pooling_layer.stride_nd = nv_strides
+        pooling_layer.padding_nd = nv_paddings
+        pooling_layer.average_count_excludes_padding = exclusive
+        if padding_algorithm == "SAME":
+            pooling_layer.padding_mode = trt.PaddingMode.SAME_UPPER
 
-    pool_layer = network.add_pooling_nd(
-        input_tensor, pooling_type, window_size=kernel_size
+        layer = pooling_layer
+    elif not adaptive and not global_pooling and ceil_mode:
+        pooling_layer = network.add_pooling_nd(
+            input=input_tensor, type=nv_pool_type, window_size=nv_ksize
+        )
+        pooling_layer.stride_nd = nv_strides
+        pooling_layer.padding_nd = nv_paddings
+        pooling_layer.average_count_excludes_padding = exclusive
+        if padding_algorithm == "SAME":
+            pooling_layer.padding_mode = trt.PaddingMode.SAME_UPPER
+        else:
+            pooling_layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+        layer = pooling_layer
+    elif global_pooling and not adaptive:
+        reduce_axes = (1 << (input_dims - 2)) | (1 << (input_dims - 1))
+        reduce_layer = network.add_reduce(
+            input=input_tensor,
+            op=reduce_operation,
+            axes=reduce_axes,
+            keep_dims=True,
         )
-    pool_layer.stride_nd = stride
-    if pooling_type == "max":
-        pool_layer.padding_nd = padding
+        layer = reduce_layer
     else:
-        pool_layer = network.add_pooling(
-            input_tensor, pooling_type, window_size=kernel_size
+        raise NotImplementedError(
+            "The combination of attributes is not supported yet."
         )
-        pool_layer.stride = stride
-        pool_layer.padding = padding
-        if exclusive:
-            pool_layer.average_count_excludes_padding = True
-        else:
-            pool_layer.average_count_excludes_padding = False
-        if ceil_mode:
-            pool_layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
 
-    return pool_layer.get_output(0)
+    output_tensor = layer.get_output(0)
+    return output_tensor
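The SAME branch above recomputes Paddle's padding arithmetic before handing explicit values to TensorRT: it ceil-divides to get the SAME output size, then splits the required total padding so the trailing side receives the odd remainder. A standalone sketch of that arithmetic with a hypothetical helper name (plain Python, no TensorRT required):

    # Illustrative recomputation of the SAME-padding math in the converter.
    def same_padding(input_hw, kernel_size, strides):
        pads = []
        for i in range(2):
            out_size = (input_hw[i] + strides[i] - 1) // strides[i]  # ceil division
            pad_sum = max((out_size - 1) * strides[i] + kernel_size[i] - input_hw[i], 0)
            pad_0 = pad_sum // 2      # leading (top/left) pad
            pad_1 = pad_sum - pad_0   # trailing pad takes the odd remainder
            pads.append((pad_0, pad_1))
        return pads

    # Example: 7x7 input, 3x3 window, stride 2 -> output 4x4, pad 1 on each side.
    assert same_padding((7, 7), (3, 3), (2, 2)) == [(1, 1), (1, 1)]

Because the trailing side takes the larger half, this corresponds to trt.PaddingMode.SAME_UPPER, which the converter sets for the SAME case.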

test/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -28,5 +28,5 @@ if(NOT WIN32 AND TENSORRT_FOUND)
   set_tests_properties(test_converter_linalg PROPERTIES TIMEOUT "100")
   set_tests_properties(test_converter_search PROPERTIES TIMEOUT "300")
   set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "300")
-
+  set_tests_properties(test_converter_pooling PROPERTIES TIMEOUT "300")
 endif()
