
Commit 8d002ab

zeroRains authored and lixcli committed
[Prim][PIR] Forward decomposite the lerp op (PaddlePaddle#65967)
* forward decomposite the lerp op
* lerp
* fix the bug in the get_output_dims
* polish
* fix code style
* move the modify to infermeta
* fix the bug
* fix the bug
1 parent 1e158bb commit 8d002ab

File tree: 5 files changed, 221 additions and 6 deletions

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@
     "instance_norm",
     "layer_norm",
     "leaky_relu",
+    "lerp",
     "log_loss",
     "log_softmax",
     "mean",
@@ -87,6 +88,7 @@
     "instance_norm",
     "layer_norm",
     "leaky_relu",
+    "lerp",
     "log_loss",
     "log_softmax",
     "mean",

paddle/fluid/primitive/composite/composite.h

Lines changed: 26 additions & 0 deletions
@@ -1514,6 +1514,32 @@ Tensor elu_decomp(const Tensor& x, const float alpha) {
   }
 }

+template <typename T>
+Tensor lerp_decomp(const Tensor& x, const Tensor& y, const Tensor& weight) {
+  Tensor x_cast = x;
+  Tensor y_cast = y;
+  Tensor weight_cast = weight;
+  bool need_cast = false;
+  if (is_half_dtype(x.dtype())) {
+    need_cast = true;
+    x_cast = cast<T>(x, DataType::FLOAT32);
+  }
+  if (is_half_dtype(y.dtype())) {
+    need_cast = true;
+    y_cast = cast<T>(y, DataType::FLOAT32);
+  }
+  if (is_half_dtype(weight.dtype())) {
+    need_cast = true;
+    weight_cast = cast<T>(weight, DataType::FLOAT32);
+  }
+  Tensor res = x_cast + weight_cast * (y_cast - x_cast);
+  if (need_cast) {
+    return cast<T>(res, x.dtype());
+  } else {
+    return res;
+  }
+}
+
 template <typename T>
 Tensor log_loss_decomp(const Tensor& input,
                        const Tensor& label,
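
The composite rule added above is plain linear interpolation, out = x + weight * (y - x); float16/bfloat16 inputs are promoted to float32 for the arithmetic and the result is cast back to the original dtype. A minimal NumPy sketch of the same rule for reference (illustration only, not Paddle's primitive code; lerp_ref is a made-up name):

    import numpy as np

    def lerp_ref(x, y, weight):
        # out = x + weight * (y - x), promoting half precision to float32
        # for the arithmetic and casting back, mirroring lerp_decomp above.
        compute = np.float32 if x.dtype == np.float16 else x.dtype
        xc, yc, wc = x.astype(compute), y.astype(compute), weight.astype(compute)
        return (xc + wc * (yc - xc)).astype(x.dtype)

    x = np.random.rand(3, 4).astype(np.float16)
    y = np.random.rand(3, 4).astype(np.float16)
    w = np.array(0.25, dtype=np.float16)
    print(lerp_ref(x, y, w))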

paddle/phi/infermeta/ternary.cc

Lines changed: 60 additions & 3 deletions
@@ -1142,9 +1142,66 @@ void LerpInferMeta(const MetaTensor& x,
   auto x_dims = x.dims();
   auto y_dims = y.dims();
   auto w_dims = weight.dims();
-  DDim out_dims;
-  out_dims = funcs::GetOutputDims(x_dims, y_dims);
-  out_dims = funcs::GetOutputDims(out_dims, w_dims);
+  DDim l_dims, s_dims;
+  if (x_dims.size() > y_dims.size()) {
+    l_dims = x_dims;
+    s_dims = y_dims;
+  } else {
+    l_dims = y_dims;
+    s_dims = x_dims;
+  }
+  std::vector<int64_t> shapes = common::vectorize<int64_t>(l_dims);
+  for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
+    int64_t s = s_dims[i];
+    int64_t l = l_dims[j];
+    if (s != l) {
+      if (l == 1) {
+        shapes[j] = s;
+      } else if (s == 1 || s == -1) {
+        shapes[j] = l;
+      } else if (l == -1) {
+        shapes[j] = s;
+      } else {
+        PADDLE_THROW(errors::InvalidArgument(
+            "The shape of tensor a %s:%d must match shape of tensor b "
+            "%s:%d.",
+            s_dims.to_str(),
+            i,
+            l_dims.to_str(),
+            j));
+      }
+    }
+  }
+  if (static_cast<int>(shapes.size()) > w_dims.size()) {
+    l_dims = common::make_ddim(shapes);
+    s_dims = w_dims;
+  } else {
+    l_dims = w_dims;
+    s_dims = common::make_ddim(shapes);
+  }
+  std::vector<int64_t> shapes_out = common::vectorize<int64_t>(l_dims);
+  for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
+    int64_t s = s_dims[i];
+    int64_t l = l_dims[j];
+    if (s != l) {
+      if (l == 1) {
+        shapes_out[j] = s;
+      } else if (s == 1 || s == -1) {
+        shapes_out[j] = l;
+      } else if (l == -1) {
+        shapes_out[j] = s;
+      } else {
+        PADDLE_THROW(errors::InvalidArgument(
+            "The shape of tensor a %s:%d must match shape of tensor b "
+            "%s:%d.",
+            s_dims.to_str(),
+            i,
+            l_dims.to_str(),
+            j));
+      }
+    }
+  }
+  DDim out_dims = common::make_ddim(shapes_out);
   out->set_dims(out_dims);
   out->set_dtype(x.dtype());
   out->share_lod(x);
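
With this change LerpInferMeta infers the output shape by broadcasting pairwise: x against y first, then that result against weight, with -1 treated as an unknown (dynamic) extent that defers to the other operand. A small Python sketch of that rule (illustration only; broadcast_dims is a hypothetical helper, not Paddle API):

    def broadcast_dims(a, b):
        # Pairwise broadcast of two dim lists; -1 means "unknown at compile time".
        long, short = (a, b) if len(a) > len(b) else (b, a)
        out = list(long)
        for i in range(-1, -len(short) - 1, -1):
            s, l = short[i], long[i]
            if s != l:
                if l == 1:
                    out[i] = s
                elif s == 1 or s == -1:
                    out[i] = l
                elif l == -1:
                    out[i] = s
                else:
                    raise ValueError(f"incompatible dims {s} vs {l}")
        return out

    xy = broadcast_dims([10, 1, 10, 5, 5], [10, 5, 1, 5, 5])
    print(broadcast_dims(xy, [1]))  # [10, 5, 10, 5, 5]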

test/legacy_test/test_lerp_op.py

Lines changed: 8 additions & 3 deletions
@@ -29,6 +29,8 @@ class TestLerp(OpTest):
     def setUp(self):
         self.op_type = "lerp"
         self.python_api = paddle.lerp
+        self.prim_op_type = "comp"
+        self.public_python_api = paddle.lerp
         self.init_dtype()
         self.init_shape()
         self.init_xyshape()
@@ -53,10 +55,10 @@ def init_wshape(self):
         self.wshape = [1]

     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, check_prim_pir=True)

     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out', check_pir=True)
+        self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)


 class TestLerpWithDim2(TestLerp):
@@ -231,6 +233,8 @@ class TestLerpBF16(TestLerp):
     def setUp(self):
         self.op_type = "lerp"
         self.python_api = paddle.lerp
+        self.prim_op_type = "comp"
+        self.public_python_api = paddle.lerp
         self.dtype = np.uint16
         self.init_shape()
         self.init_xyshape()
@@ -270,7 +274,7 @@ def init_grad(self, w):

     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_pir=True)
+        self.check_output_with_place(place, check_pir=True, check_prim_pir=True)

     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -280,6 +284,7 @@ def test_check_grad(self):
             'Out',
             user_defined_grads=[self.x_grad, self.y_grad],
             check_pir=True,
+            check_prim_pir=True,
         )

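For orientation, the public API exercised here with check_prim_pir=True is paddle.lerp, whose weight argument broadcasts like the wshape cases above; a quick usage sketch, assuming a local Paddle install:

    import paddle

    x = paddle.rand([10, 10, 5, 5], dtype='float32')
    y = paddle.rand([10, 10, 5, 5], dtype='float32')
    w = paddle.rand([5], dtype='float32')   # broadcasts across the last axis
    out = paddle.lerp(x, y, w)              # equivalent to x + w * (y - x)
    print(out.shape)                        # [10, 10, 5, 5]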

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

Lines changed: 125 additions & 0 deletions
@@ -110,6 +110,10 @@ def mean_all_net1(x):
     return paddle._C_ops.mean_all(x)


+def lerp_net(x, y, weight):
+    return paddle.lerp(x, y, weight)
+
+
 group_norm1 = paddle.nn.GroupNorm(num_channels=128, num_groups=32)


@@ -649,6 +653,127 @@ def setUp(self):
         self.enable_cinn = False


+class TestPrimThree(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [300, 2048]
+        self.shape_y = [300, 2048]
+        self.shape_z = [1]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.dtype_z = "float32"
+        self.init_x_shape = [None, 2048]
+        self.init_y_shape = [None, 2048]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
+        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
+        self.net = lerp_net
+        self.necessary_ops = "pd_op.lerp"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+    def base_net(self, flag=None):
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        z = paddle.to_tensor(self.z)
+        if flag == "prim":
+            core._set_prim_all_enabled(True)
+            fn = apply_to_static(
+                self.net,
+                use_cinn=self.enable_cinn,
+                input_spec=[
+                    InputSpec(shape=self.init_x_shape, dtype=self.dtype_x),
+                    InputSpec(shape=self.init_y_shape, dtype=self.dtype_y),
+                    InputSpec(shape=self.init_z_shape, dtype=self.dtype_z),
+                ],
+            )
+            fn.eval()
+        else:
+            fn = self.net
+        res = fn(x, y, z)
+
+        if flag == "prim":
+            ops = [
+                op.name()
+                for op in fn.program_cache.last()[-1][-1]
+                .infer_program.program.global_block()
+                .ops
+            ]
+            assert self.necessary_ops not in ops
+            core._set_prim_all_enabled(False)
+        return res
+
+    def test_prim_all_dynamic(self):
+        res_ref = self.base_net()
+        res = self.base_net("prim")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_allclose(ref, actual, rtol=self.tol)
+
+
+class TestPrimLerp1(TestPrimThree):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [10, 1, 10, 5, 5]
+        self.shape_y = [10, 5, 1, 5, 5]
+        self.shape_z = [1]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.dtype_z = "float32"
+        self.init_x_shape = [None, None, None, 5, 5]
+        self.init_y_shape = [None, None, None, 5, 5]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
+        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
+        self.net = lerp_net
+        self.necessary_ops = "pd_op.lerp"
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
+class TestPrimLerp2(TestPrimThree):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [10, 10, 5, 5]
+        self.shape_y = [10, 10, 5, 5]
+        self.shape_z = [5]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.dtype_z = "float32"
+        self.init_x_shape = [None, None, 5, 5]
+        self.init_y_shape = [None, None, 5, 5]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
+        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
+        self.net = lerp_net
+        self.necessary_ops = "pd_op.lerp"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimLerp3(TestPrimThree):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [10, 5, 10, 1, 5]
+        self.shape_y = [10, 5, 10, 5, 1]
+        self.shape_z = [1]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.dtype_z = "float32"
+        self.init_x_shape = [None, None, None, 1, 5]
+        self.init_y_shape = [None, None, None, 5, 1]
+        self.init_z_shape = [None]
+        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
+        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
+        self.z = np.random.random(self.shape_z).astype(self.dtype_z)
+        self.net = lerp_net
+        self.necessary_ops = "pd_op.lerp"
+        self.enable_cinn = False
+        self.tol = 1e-5
+
+
 class TestPrimLogLoss1(TestPrimTwo):
     def setUp(self):
         np.random.seed(2023)
