Skip to content

Commit 8778957

Browse files
authored
Add element-wise multiplication operator. (#3787)
Add element-wise multiplication operator
1 parent 0f42e56 commit 8778957

File tree

5 files changed

+477
-0
lines changed

5 files changed

+477
-0
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/elementwise_mul_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
22+
// Forward operator for elementwise multiplication: Out = X ⊙ Y.
// Validates inputs and infers the output shape during compile time.
class ElementWiseMulOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Out always takes the shape of X; Y may have equal or lower rank
  // (the kernel broadcasts Y to X's shape at run time).
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
    auto x_dim = ctx.Input<Tensor>("X")->dims();
    auto y_dim = ctx.Input<Tensor>("Y")->dims();
    // Broadcasting is only supported from lower-rank Y to higher-rank X.
    // NOTE(review): no trailing ';' after this macro call — presumably the
    // PADDLE_ENFORCE_GE expansion tolerates it; confirm against the macro.
    PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                      "Rank of first input must >= rank of second input.")
    ctx.Output<Tensor>("Out")->Resize(x_dim);
  }
};
37+
38+
// Declares the inputs, output, attributes, and user-facing documentation
// of the elementwise_mul operator.
class ElementWiseMulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ElementWiseMulOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The first input of elementwise mul op");
    AddInput("Y", "The second input of elementwise mul op");
    // axis == -1 (the default) means "align Y with the trailing
    // dimensions of X"; the kernel rewrites it to
    // x_dims.size() - y_dims.size().
    AddAttr<int>("axis",
                 R"DOC(
When shape(Y) does not equal shape(X), Y will be broadcasted
to match the shape of X and axis should be dimension index Y in X
        )DOC")
        .SetDefault(-1)
        .EqualGreaterThan(-1);

    AddOutput("Out", "The output of elementwise mul op");
    // Fixed typos in the published docs: "multiple operator.The" ->
    // "multiplication operator. The".
    AddComment(R"DOC(
Limited elementwise multiplication operator. The equation is: Out = X ⊙ Y.
1. The shape of Y should be same with X or
2. Y's shape is a subset of X.
   Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.

   example:
     shape(X) = (2, 3, 4, 5), shape(Y) = (,)
     shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
     shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
     shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
     shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
)DOC");
  }
};
68+
69+
// Gradient operator for elementwise_mul. Shapes the gradients of X and Y
// to match their forward inputs; either gradient output may be absent.
class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) should not be null");

    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto y_dims = ctx.Input<Tensor>("Y")->dims();
    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));

    // Same rank constraint as the forward op.
    // NOTE(review): missing trailing ';' after this macro call, matching
    // the forward op — presumably the macro expansion tolerates it.
    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                      "Rank of first input must >= rank of second input.")

    // Gradients are optional: a null output var means the caller does not
    // need that gradient.
    if (x_grad) {
      x_grad->Resize(x_dims);
    }

    if (y_grad) {
      y_grad->Resize(y_dims);
    }
  }
};
98+
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward/backward operator pair under the names
// "elementwise_mul" / "elementwise_mul_grad".
REGISTER_OP(elementwise_mul, ops::ElementWiseMulOp, ops::ElementWiseMulOpMaker,
            elementwise_mul_grad, ops::ElementWiseMulOpGrad);
// CPU kernels; only float is instantiated here.
REGISTER_OP_CPU_KERNEL(
    elementwise_mul,
    ops::ElementWiseMulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    elementwise_mul_grad,
    ops::ElementWiseMulGradKernel<paddle::platform::CPUPlace, float>);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
// GPU registration TU: EIGEN_USE_GPU must be defined before including the
// kernel header so Eigen expressions compile for the GPU device.
#define EIGEN_USE_GPU
#include "paddle/operators/elementwise_mul_op.h"

namespace ops = paddle::operators;

// GPU kernels; only float is instantiated here, mirroring the CPU side.
REGISTER_OP_GPU_KERNEL(
    elementwise_mul,
    ops::ElementWiseMulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    elementwise_mul_grad,
    ops::ElementWiseMulGradKernel<paddle::platform::GPUPlace, float>);
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <iostream>
17+
#include "paddle/framework/eigen.h"
18+
#include "paddle/framework/op_registry.h"
19+
#include "paddle/operators/math/math_function.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
/*
24+
* Out = X ⊙ Y
25+
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
26+
* pre=2, n=3*4, post=5
27+
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
28+
* pre=2*3, n=4*5, post=1
29+
*/
30+
31+
inline void get_mid_dims(const framework::DDim& x_dims,
32+
const framework::DDim& y_dims, const int axis,
33+
int& pre, int& n, int& post) {
34+
pre = 1;
35+
n = 1;
36+
post = 1;
37+
for (int i = 0; i < axis; ++i) {
38+
pre *= x_dims[i];
39+
}
40+
41+
for (int i = 0; i < y_dims.size(); ++i) {
42+
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
43+
"Broadcast dimension mismatch.");
44+
n *= y_dims[i];
45+
}
46+
47+
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
48+
post *= x_dims[i];
49+
}
50+
}
51+
52+
// Forward kernel: Out = X ⊙ Y with limited broadcasting of Y onto X.
// Y is flattened, reshaped to (pre, n[, post]) around the broadcast axis,
// tiled to X's element count, then multiplied elementwise.
template <typename Place, typename T>
class ElementWiseMulKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    z->mutable_data<T>(ctx.GetPlace());

    // All tensors are viewed as flat 1-D Eigen vectors.
    auto x_e = framework::EigenVector<T>::Flatten(*x);
    auto y_e = framework::EigenVector<T>::Flatten(*y);
    auto z_e = framework::EigenVector<T>::Flatten(*z);

    auto x_dims = x->dims();
    auto y_dims = y->dims();
    // NOTE(review): missing trailing ';' after this macro call —
    // presumably the macro expansion tolerates it; confirm.
    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                      "Rank of first input must >= rank of second input.")

    // Fast path: identical shapes, or scalar Y — no broadcast needed.
    if (x_dims == y_dims || product(y_dims) == 1) {
      z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_e;
      return;
    }

    // axis == -1 means Y aligns with the trailing dimensions of X.
    int axis = ctx.Attr<int>("axis");
    axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
    PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                   "Axis should be in range [0, x_dims)");

    int pre, n, post;
    get_mid_dims(x_dims, y_dims, axis, pre, n, post);
    if (post == 1) {
      // Y spans X's trailing dims: tile Y `pre` times.
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
      z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_bcast;
      return;
    } else {
      // Y spans middle dims: tile Y `pre` times outer, `post` times inner.
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
      z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_bcast;
      return;
    }
  }
};
99+
100+
// Backward kernel for Out = X ⊙ Y:
//   dX = dOut ⊙ broadcast(Y)
//   dY = reduce_sum(X ⊙ dOut) over the broadcast dimensions
// Either gradient output may be null if the caller does not need it.
template <typename Place, typename T>
class ElementWiseMulGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto x_e = framework::EigenVector<T>::Flatten(*x);
    auto y_e = framework::EigenVector<T>::Flatten(*y);
    auto dout_e = framework::EigenVector<T>::Flatten(*dout);

    auto x_dims = x->dims();
    auto y_dims = y->dims();

    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    if (dx) {
      dx->mutable_data<T>(ctx.GetPlace());
    }
    if (dy) {
      dy->mutable_data<T>(ctx.GetPlace());
    }

    // Fast path: identical shapes or scalar Y — no reduction needed.
    if (x_dims == y_dims || product(y_dims) == 1) {
      if (dx) {
        auto dx_e = framework::EigenVector<T>::Flatten(*dx);
        dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e;
      }

      if (dy) {
        auto dy_e = framework::EigenVector<T>::Flatten(*dy);
        dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e;
      }
      return;
    }

    // axis == -1 means Y aligns with the trailing dimensions of X.
    int axis = ctx.Attr<int>("axis");
    axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);

    int pre, n, post;
    get_mid_dims(x_dims, y_dims, axis, pre, n, post);

    // TODO(gongweibao): wrap reshape to a function.
    if (post == 1) {
      // Y spans X's trailing dims: broadcast over `pre` for dX,
      // reduce over axis 0 of the (pre, n) view for dY.
      auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
                           .broadcast(Eigen::DSizes<int, 2>(pre, 1))
                           .reshape(Eigen::DSizes<int, 1>(x_e.size()));
      if (dx) {
        auto dx_e = framework::EigenVector<T>::Flatten(*dx);
        dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e_bcast;
      }

      if (dy) {
        auto dy_e = framework::EigenVector<T>::Flatten(*dy);
        dy_e.device(ctx.GetEigenDevice<Place>()) =
            (x_e * dout_e)
                .reshape(Eigen::DSizes<int, 2>(pre, n))
                .sum(Eigen::array<int, 1>{{0}});
      }
      return;
    } else {
      // Y spans middle dims: broadcast over `pre` and `post` for dX,
      // reduce over axes {0, 2} of the (pre, n, post) view for dY.
      auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
                           .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
                           .reshape(Eigen::DSizes<int, 1>(x_e.size()));
      if (dx) {
        auto dx_e = framework::EigenVector<T>::Flatten(*dx);
        dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e_bcast;
      }

      if (dy) {
        auto dy_e = framework::EigenVector<T>::Flatten(*dy);
        dy_e.device(ctx.GetEigenDevice<Place>()) =
            (x_e * dout_e)
                .reshape(Eigen::DSizes<int, 3>(pre, n, post))
                .sum(Eigen::array<int, 2>{{0, 2}});
      }
      return;
    }
  }
};
183+
184+
} // namespace operators
185+
} // namespace paddle

paddle/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ USE_OP(add);
3535
USE_OP(onehot_cross_entropy);
3636
USE_OP(sgd);
3737
USE_OP(mul);
38+
USE_OP(elementwise_mul);
3839
USE_OP(mean);
3940
USE_OP(sigmoid);
4041
USE_OP(softmax);

0 commit comments

Comments
 (0)