Skip to content

Commit 1fcb469

Browse files
authored
【SCU】[Paddle TensorRT No.28] Add pd_op.logical_xor converter (#70170)
* test * test * test * test
1 parent 2da0886 commit 1fcb469

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,8 @@ class LogicalCommonOpPattern : public pir::OpRewritePattern<OpType> {
12271227
return true;
12281228
}
12291229
};
1230+
using LogicalXorOpPattern =
1231+
LogicalCommonOpPattern<paddle::dialect::LogicalXorOp>;
12301232
using LogicalOrOpPattern = LogicalCommonOpPattern<paddle::dialect::LogicalOrOp>;
12311233
using LogicalOr_OpPattern =
12321234
LogicalCommonOpPattern<paddle::dialect::LogicalOr_Op>;
@@ -2268,6 +2270,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
22682270
ps.Add(std::make_unique<SetValueWithTensor_OpPattern>(context));
22692271
ps.Add(std::make_unique<EqualOpPattern>(context));
22702272
ps.Add(std::make_unique<NotEqualOpPattern>(context));
2273+
ps.Add(std::make_unique<LogicalXorOpPattern>(context));
22712274
ps.Add(std::make_unique<TanhOpPattern>(context));
22722275
ps.Add(std::make_unique<CeluOpPattern>(context));
22732276
ps.Add(std::make_unique<OneHotOpPattern>(context));

python/paddle/tensorrt/impls/logic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"pd_op.greater_than": trt.ElementWiseOperation.GREATER,
2424
"pd_op.less_than": trt.ElementWiseOperation.LESS,
2525
"pd_op.equal": trt.ElementWiseOperation.EQUAL,
26+
"pd_op.logical_xor": trt.ElementWiseOperation.XOR,
2627
"pd_op.logical_or": trt.ElementWiseOperation.OR,
2728
"pd_op.logical_or_": trt.ElementWiseOperation.OR,
2829
"pd_op.logical_and": trt.ElementWiseOperation.AND,
@@ -32,6 +33,7 @@
3233
@converter_registry.register("pd_op.greater_than", trt_version="8.x")
3334
@converter_registry.register("pd_op.less_than", trt_version="8.x")
3435
@converter_registry.register("pd_op.equal", trt_version="8.x")
36+
@converter_registry.register("pd_op.logical_xor", trt_version="8.x")
3537
@converter_registry.register("pd_op.logical_or", trt_version="8.x")
3638
@converter_registry.register("pd_op.logical_or_", trt_version="8.x")
3739
@converter_registry.register("pd_op.logical_and", trt_version="8.x")

test/tensorrt/test_converter_logic.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,5 +257,44 @@ def test_trt_result(self):
257257
self.check_marker(expected_result=False)
258258

259259

260+
class TestLogicalXorTRTPattern(TensorRTBaseTest):
261+
def setUp(self):
262+
self.python_api = paddle.logical_xor
263+
264+
def test_trt_result(self):
265+
self.api_args = {
266+
"x": np.random.choice([True, False], size=(3,)).astype("bool"),
267+
"y": np.random.choice([True, False], size=(3,)).astype("bool"),
268+
}
269+
self.program_config = {"feed_list": ["x", "y"]}
270+
self.min_shape = {"x": [1], "y": [1]}
271+
self.max_shape = {"x": [5], "y": [5]}
272+
self.check_trt_result()
273+
274+
def test_trt_diff_shape_result(self):
275+
self.api_args = {
276+
"x": np.random.choice([True, False], size=(2, 3)).astype("bool"),
277+
"y": np.random.choice([True, False], size=(3)).astype("bool"),
278+
}
279+
self.program_config = {"feed_list": ["x", "y"]}
280+
self.min_shape = {"x": [1, 3], "y": [3]}
281+
self.max_shape = {"x": [4, 3], "y": [3]}
282+
self.check_trt_result()
283+
284+
285+
class TestLogicalXorMarker(TensorRTBaseTest):
286+
def setUp(self):
287+
self.python_api = paddle.logical_xor
288+
self.api_args = {
289+
"x": np.random.randn(3).astype("int64"),
290+
"y": np.random.randn(3).astype("int64"),
291+
}
292+
self.program_config = {"feed_list": ["x", "y"]}
293+
self.target_marker_op = "pd_op.logical_xor"
294+
295+
def test_trt_result(self):
296+
self.check_marker(expected_result=False)
297+
298+
260299
if __name__ == '__main__':
261300
unittest.main()

0 commit comments

Comments
 (0)