Skip to content

Commit 5afcfde

Browse files
authored
[Prim][PIR] fallback mechanism to support dropout op (#66711)
* Add a fallback mechanism to support the dropout op
* Add a fallback unittest
* Rename the test file and update CMakeLists.txt
* Remove unnecessary comments
* Rename a class method
1 parent 2fd3987 commit 5afcfde

File tree

5 files changed

+162
-11
lines changed

5 files changed

+162
-11
lines changed

paddle/fluid/primitive/base/decomp_trans.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,10 @@ void DecompProgram::decomp_block(
480480
auto& builder = *(paddle::dialect::ApiBuilder::Instance().GetBuilder());
481481
builder.set_insertion_point(op);
482482
std::vector<std::vector<pir::Value>> decomp_res = call_decomp_rule(op);
483+
if (decomp_res.size() == 0) {
484+
// if we don't decomp this op, then leave it intact.
485+
continue;
486+
}
483487
std::vector<pir::Value> orig_outs = op->results();
484488
bool is_next_builtin_split_slice = false;
485489

paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,7 @@ std::vector<std::vector<pir::Value>> {{class_name}}::Decomp(pir::Operation* op)
8181
->value()
8282
.defining_op();
8383
if ({{item.name}}_define_op->name() != "pd_op.full") {
84-
PADDLE_THROW(
85-
common::errors::Unimplemented("We don't support dynamic tensors "
86-
"attribute {{item.name}} for {{fwd_name}} decomposition "
87-
"for now. "));
84+
return {};
8885
}
8986
Scalar {{item.name}} = {{item.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();
9087

@@ -97,10 +94,7 @@ std::vector<std::vector<pir::Value>> {{class_name}}::Decomp(pir::Operation* op)
9794
->value()
9895
.defining_op();
9996
if ({{item.name}}_define_op->name() != "pd_op.full_int_array") {
100-
PADDLE_THROW(
101-
common::errors::Unimplemented("We don't support dynamic tensors "
102-
"attribute {{item.name}} for {{fwd_name}} decomposition "
103-
"for now. "));
97+
return {};
10498
}
10599
IntArray {{item.name}} = phi::IntArray(
106100
paddle::dialect::GetInt64Vector({{item.name}}_define_op->attribute("value")));
@@ -120,7 +114,7 @@ std::vector<std::vector<pir::Value>> {{class_name}}::Decomp(pir::Operation* op)
120114
}
121115

122116
} else {
123-
PADDLE_THROW(common::errors::Unimplemented("attr is not vector of {{temp_type}} "));
117+
return {};
124118
}
125119
}
126120
{% else %}

paddle/fluid/primitive/composite/composite.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ std::tuple<Tensor, Tensor> dropout_decomp(
856856
// train: out = input * mask / ( 1.0 - p )
857857
if (p.to<float>() == 1.0) {
858858
// Process p=1. for avoid divide zero error (x*mask/(1.0-p))
859-
auto zero = full_scalar<T>(0.0, org_dtype);
859+
auto zero = full_like_decomp<T>(x, 0.0, org_dtype, x.place());
860860
return std::make_tuple(x * zero, cast<T>(zero, DataType::UINT8));
861861
} else {
862862
auto ans = (x * mask) / ones_p;

test/prim/pir_prim/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ set(TEST_PRIM_PURE_PIR_CASES
2222
test_decompose_control_flow
2323
test_decomp_whole_program
2424
test_dynamic_combine1
25-
test_dynamic_combine2)
25+
test_dynamic_combine2
26+
test_decomp_fallback)
2627

2728
foreach(target ${TEST_PRIM_PURE_PIR_CASES})
2829
py_test_modules(
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
19+
import paddle
20+
21+
22+
class TestFallBackBase(unittest.TestCase):
    """Shared fixture for decomposition-fallback tests.

    Subclasses assign ``func_api`` (the callable under test) plus the
    concrete inputs they need; this base only fixes dtype and tolerance.
    """

    def setUp(self):
        # Placeholder; each subclass overrides with the API to exercise.
        self.func_api = None
        self.dtype = np.float32
        self.tol = 1e-6
27+
28+
29+
def custom_dropout(x, p):
    """Dropout on ``x`` with probability ``p``, then shift by 2.0."""
    dropped = paddle.nn.functional.dropout(x, p)
    return dropped + 2.0
31+
32+
33+
class TestDropOutFallBack(TestFallBackBase):
    """Dropout with a tensor-valued ``p`` must fall back, not crash."""

    def setUp(self):
        super().setUp()
        self.func_api = custom_dropout
        self.x = paddle.to_tensor([[1.0, -2], [3.0, 4]], dtype=self.dtype)
        self.p = paddle.to_tensor(0.0, dtype=self.dtype)

    def test_fallback(self):
        # Compare the to_static (decomposed-or-fallback) result with eager.
        compiled = paddle.jit.to_static(self.func_api, full_graph=True)
        eager = self.func_api

        actual_out = compiled(self.x, self.p)
        expected_out = eager(self.x, self.p)

        for expected, got in zip(expected_out, actual_out):
            np.testing.assert_allclose(
                expected, got, rtol=self.tol, atol=self.tol
            )
51+
52+
53+
def custom_full(shape, value):
    """full_like of ``shape`` filled with ``value``, shifted by 2.0.

    NOTE(review): despite its name, ``shape`` receives a Tensor here —
    ``paddle.full_like`` takes a reference tensor, not a shape list.
    """
    filled = paddle.full_like(shape, value)
    return filled + 2.0
55+
56+
57+
class TestFullLikeFallBack(TestFallBackBase):
    """full_like with a tensor-valued fill value must fall back cleanly."""

    def setUp(self):
        super().setUp()
        self.func_api = custom_full
        self.x = paddle.to_tensor([[1.0, -2], [3.0, 4]], dtype=self.dtype)
        self.value = paddle.to_tensor(2, dtype=self.dtype)

    def test_fallback(self):
        # Compare the to_static (decomposed-or-fallback) result with eager.
        compiled = paddle.jit.to_static(self.func_api, full_graph=True)
        eager = self.func_api

        actual_out = compiled(self.x, self.value)
        expected_out = eager(self.x, self.value)

        for expected, got in zip(expected_out, actual_out):
            np.testing.assert_allclose(
                expected, got, rtol=self.tol, atol=self.tol
            )
75+
76+
77+
def custom_squeeze(x, axis):
    """Squeeze ``axis`` out of ``x``, then shift by 2.0."""
    squeezed = paddle.squeeze(x, axis)
    return squeezed + 2.0
79+
80+
81+
class TestSqueezeFallBack(TestFallBackBase):
    """squeeze with a tensor-valued axis must fall back cleanly."""

    def setUp(self):
        super().setUp()
        self.func_api = custom_squeeze
        self.x = paddle.rand([5, 1, 10], dtype=self.dtype)
        self.axis = paddle.to_tensor(1, dtype=paddle.int64)

    def test_fallback(self):
        # Compare the to_static (decomposed-or-fallback) result with eager.
        compiled = paddle.jit.to_static(self.func_api, full_graph=True)
        eager = self.func_api

        actual_out = compiled(self.x, self.axis)
        expected_out = eager(self.x, self.axis)

        for expected, got in zip(expected_out, actual_out):
            np.testing.assert_allclose(
                expected, got, rtol=self.tol, atol=self.tol
            )
99+
100+
101+
def custom_unsqueeze(x, axis):
    """Insert size-1 dims into ``x`` at ``axis``, then shift by 2.0."""
    expanded = paddle.unsqueeze(x, axis)
    return expanded + 2.0
103+
104+
105+
class TestUnsqueezeFallBack(TestFallBackBase):
    """unsqueeze with a tensor-valued axis list must fall back cleanly."""

    def setUp(self):
        super().setUp()
        self.func_api = custom_unsqueeze
        self.x = paddle.rand([5, 10], dtype=self.dtype)
        self.axis = paddle.to_tensor([0, 2], dtype=paddle.int64)

    def test_fallback(self):
        # Compare the to_static (decomposed-or-fallback) result with eager.
        compiled = paddle.jit.to_static(self.func_api, full_graph=True)
        eager = self.func_api

        actual_out = compiled(self.x, self.axis)
        expected_out = eager(self.x, self.axis)

        for expected, got in zip(expected_out, actual_out):
            np.testing.assert_allclose(
                expected, got, rtol=self.tol, atol=self.tol
            )
123+
124+
125+
def custom_any(x, axis):
    """Reduce ``x`` with logical-any along ``axis``."""
    return paddle.any(x, axis)
127+
128+
129+
class TestAnyFallBack(TestFallBackBase):
    """``any`` must fall back cleanly when its decomposition declines."""

    def setUp(self):
        super().setUp()
        self.func_api = custom_any
        self.x = paddle.to_tensor([[1, 0], [1, 1]], dtype='int32').cast('bool')
        # The ``axis`` argument cannot be a list of tensors: the framework
        # validates argument types before decomposition, so a plain Python
        # list is used here.
        self.axis = [0]

    def test_fallback(self):
        # Compare the to_static (decomposed-or-fallback) result with eager.
        compiled = paddle.jit.to_static(self.func_api, full_graph=True)
        eager = self.func_api

        actual_out = compiled(self.x, self.axis)
        expected_out = eager(self.x, self.axis)

        for expected, got in zip(expected_out, actual_out):
            np.testing.assert_allclose(
                expected, got, rtol=self.tol, atol=self.tol
            )
149+
150+
151+
# Allow running this module directly as a test script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)