Skip to content

Commit 13f1641

Browse files
authored
move elementwise_mul selected rows input (#41042)
1 parent 04325d2 commit 13f1641

File tree

10 files changed

+161
-205
lines changed

10 files changed

+161
-205
lines changed

paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "paddle/phi/core/kernel_registry.h"
2424

2525
USE_OP_ITSELF(scale);
26-
USE_OP(elementwise_mul);
26+
USE_OP_ITSELF(elementwise_mul);
2727
USE_OP_ITSELF(elementwise_add);
2828
USE_OP_ITSELF(elementwise_add_grad);
2929

paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,4 @@ TEST(elementwise_op, plugin) {
104104
} // namespace paddle
105105

106106
USE_OP_ITSELF(elementwise_add);
107-
USE_OP(elementwise_mul);
107+
USE_OP_ITSELF(elementwise_mul);

paddle/fluid/operators/elementwise/elementwise_mul_op.cc

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,6 @@ limitations under the License. */
2020

2121
namespace paddle {
2222
namespace operators {
23-
24-
template <typename T>
25-
struct SameDimsElemwiseMul<
26-
platform::CPUDeviceContext, T,
27-
typename std::enable_if<std::is_floating_point<T>::value>::type> {
28-
void operator()(const framework::ExecutionContext &ctx,
29-
const framework::Tensor *x, const framework::Tensor *y,
30-
framework::Tensor *z) {
31-
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(ctx);
32-
blas.VMUL(x->numel(), x->data<T>(), y->data<T>(), z->data<T>());
33-
}
34-
};
35-
36-
template <typename T>
37-
struct SameDimsElemwiseMul<
38-
platform::CPUDeviceContext, T,
39-
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
40-
void operator()(const framework::ExecutionContext &ctx,
41-
const framework::Tensor *x, const framework::Tensor *y,
42-
framework::Tensor *z) {
43-
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
44-
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
45-
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
46-
auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
47-
.eigen_device();
48-
eigen_z.device(place) = eigen_x * eigen_y;
49-
}
50-
};
51-
5223
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
5324
protected:
5425
std::string GetName() const override { return "Mul"; }
@@ -160,20 +131,6 @@ REGISTER_OPERATOR(
160131

161132
REGISTER_OPERATOR(elementwise_mul_triple_grad, ops::ElementwiseOpTripleGrad);
162133

163-
REGISTER_OP_CPU_KERNEL(
164-
elementwise_mul,
165-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
166-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
167-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
168-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
169-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, bool>,
170-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
171-
paddle::platform::bfloat16>,
172-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
173-
paddle::platform::complex<float>>,
174-
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
175-
paddle::platform::complex<double>>);
176-
177134
REGISTER_OP_VERSION(elementwise_mul)
178135
.AddCheckpoint(
179136
R"ROC(Register elementwise_mul for adding the attribute of Scale_y)ROC",

paddle/fluid/operators/elementwise/elementwise_mul_op.cu

Lines changed: 0 additions & 78 deletions
This file was deleted.

paddle/fluid/operators/elementwise/elementwise_mul_op.h

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -58,85 +58,5 @@ class ElementwiseMulOp : public ElementwiseOp {
5858
}
5959
};
6060

61-
template <typename DeviceContext, typename T>
62-
void default_elementwise_mul(const framework::ExecutionContext& ctx,
63-
const framework::Tensor* x,
64-
const framework::Tensor* y, framework::Tensor* z) {
65-
int axis = ctx.Attr<int>("axis");
66-
auto x_dims = x->dims();
67-
auto y_dims = y->dims();
68-
if (x_dims.size() >= y_dims.size()) {
69-
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
70-
MulFunctor<T>(), z);
71-
} else {
72-
ElementwiseComputeEx<InverseMulFunctor<T>, DeviceContext, T>(
73-
ctx, x, y, axis, InverseMulFunctor<T>(), z);
74-
}
75-
}
76-
77-
template <typename DeviceContext, typename T, class Enable = void>
78-
struct SameDimsElemwiseMul {
79-
void operator()(const framework::ExecutionContext& ctx,
80-
const framework::Tensor* x, const framework::Tensor* y,
81-
framework::Tensor* z);
82-
};
83-
84-
template <typename DeviceContext, typename T>
85-
class ElementwiseMulKernel : public framework::OpKernel<T> {
86-
public:
87-
void Compute(const framework::ExecutionContext& ctx) const override {
88-
auto x_var = ctx.InputVar("X");
89-
PADDLE_ENFORCE_EQ(x_var != nullptr, true,
90-
platform::errors::InvalidArgument(
91-
"Cannot get input Variable X, Variable name = %s.",
92-
ctx.InputName("X")));
93-
auto* y = ctx.Input<framework::LoDTensor>("Y");
94-
95-
framework::Tensor x, *z;
96-
if (x_var->IsType<phi::SelectedRows>()) {
97-
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
98-
platform::errors::InvalidArgument(
99-
"For elementwise_op, if X is Sparse, Y must be "
100-
"scalar. But reveived the size of Y = %s.",
101-
y->dims().size()));
102-
auto& x_sele = x_var->Get<phi::SelectedRows>();
103-
auto out_sele = ctx.Output<phi::SelectedRows>("Out");
104-
x = x_sele.value();
105-
out_sele->set_rows(x_sele.rows());
106-
out_sele->set_height(x_sele.height());
107-
out_sele->mutable_value()->Resize(x_sele.value().dims());
108-
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
109-
z = ctx.Output<phi::SelectedRows>("Out")->mutable_value();
110-
z->mutable_data<T>(ctx.GetPlace());
111-
auto dims_equal = x.dims() == y->dims();
112-
if (dims_equal) {
113-
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
114-
same_dims_mul(ctx, &x, y, z);
115-
} else {
116-
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
117-
}
118-
} else if (x_var->IsType<framework::LoDTensor>()) {
119-
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
120-
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
121-
z_lod->mutable_data<T>(ctx.GetPlace());
122-
123-
auto& dev_ctx = ctx.device_context<DeviceContext>();
124-
int axis = ctx.Attr<int>("axis");
125-
auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod);
126-
auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
127-
auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod);
128-
phi::MultiplyRawKernel<T>(
129-
static_cast<const typename framework::ConvertToPhiContext<
130-
DeviceContext>::TYPE&>(dev_ctx),
131-
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
132-
} else {
133-
PADDLE_THROW(platform::errors::InvalidArgument(
134-
"X's type[%s] is not supported by elementwise_op. X's type should be "
135-
"LoDTensor or SelectedRows.",
136-
framework::ToTypeName(x_var->Type())));
137-
}
138-
}
139-
};
140-
14161
} // namespace operators
14262
} // namespace paddle

paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
USE_OP_ITSELF(elementwise_add);
2929
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
30-
USE_OP(elementwise_mul);
30+
USE_OP_ITSELF(elementwise_mul);
3131
USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
3232
USE_OP_ITSELF(relu);
3333
USE_OP_DEVICE_KERNEL(relu, MKLDNN);

paddle/phi/kernels/elementwise_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ PD_REGISTER_KERNEL(multiply,
202202
int64_t,
203203
bool,
204204
phi::dtype::float16,
205+
phi::dtype::bfloat16,
205206
complex64,
206207
complex128) {}
207208
PD_REGISTER_KERNEL(maximum,
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/kernels/selected_rows/elementwise_kernel.h"
16+
17+
#include "paddle/phi/backends/cpu/cpu_context.h"
18+
#include "paddle/phi/common/bfloat16.h"
19+
#include "paddle/phi/common/complex.h"
20+
#include "paddle/phi/common/float16.h"
21+
#include "paddle/phi/core/enforce.h"
22+
#include "paddle/phi/core/kernel_registry.h"
23+
#include "paddle/phi/kernels/elementwise_kernel.h"
24+
25+
namespace phi {
26+
namespace sr {
27+
28+
template <typename T, typename Context>
29+
void MultiplyRawKernel(const Context& dev_ctx,
30+
const SelectedRows& x,
31+
const DenseTensor& y,
32+
int axis,
33+
SelectedRows* out) {
34+
PADDLE_ENFORCE_EQ(y.dims().size() == 1 && y.dims()[0] == 1,
35+
true,
36+
phi::errors::InvalidArgument(
37+
"For MultiplyKernel, if X is Sparse, Y must be "
38+
"scalar. But reveived the size of Y = %s.",
39+
y.dims().size()));
40+
out->set_rows(x.rows());
41+
out->set_height(x.height());
42+
auto z = out->mutable_value();
43+
z->Resize(x.value().dims());
44+
dev_ctx.Alloc(z, x.value().dtype());
45+
MultiplyRawKernel<T, Context>(dev_ctx, x.value(), y, axis, z);
46+
}
47+
48+
template <typename T, typename Context>
49+
void MultiplyKernel(const Context& dev_ctx,
50+
const SelectedRows& x,
51+
const DenseTensor& y,
52+
SelectedRows* out) {
53+
int axis = -1;
54+
MultiplyRawKernel<T, Context>(dev_ctx, x, y, axis, out);
55+
}
56+
57+
} // namespace sr
58+
} // namespace phi
59+
60+
using complex64 = ::phi::dtype::complex<float>;
61+
using complex128 = ::phi::dtype::complex<double>;
62+
63+
PD_REGISTER_KERNEL(multiply_raw_sr,
64+
CPU,
65+
ALL_LAYOUT,
66+
phi::sr::MultiplyRawKernel,
67+
float,
68+
double,
69+
int,
70+
int64_t,
71+
bool,
72+
phi::dtype::bfloat16,
73+
complex64,
74+
complex128) {}
75+
PD_REGISTER_KERNEL(multiply_sr,
76+
CPU,
77+
ALL_LAYOUT,
78+
phi::sr::MultiplyKernel,
79+
float,
80+
double,
81+
int,
82+
int64_t,
83+
bool,
84+
phi::dtype::bfloat16,
85+
complex64,
86+
complex128) {}
87+
88+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU registration of the SelectedRows multiply kernel that takes an
// explicit axis argument. Unlike the CPU list, float16 is registered here.
PD_REGISTER_KERNEL(multiply_raw_sr, GPU, ALL_LAYOUT,
                   phi::sr::MultiplyRawKernel,
                   float, double, int, int64_t, bool,
                   phi::dtype::bfloat16, phi::dtype::float16,
                   complex64, complex128) {}

// GPU registration of the SelectedRows multiply kernel without an axis
// argument; same element-type list as the raw variant above.
PD_REGISTER_KERNEL(multiply_sr, GPU, ALL_LAYOUT,
                   phi::sr::MultiplyKernel,
                   float, double, int, int64_t, bool,
                   phi::dtype::bfloat16, phi::dtype::float16,
                   complex64, complex128) {}
#endif

0 commit comments

Comments
 (0)