Commit 770ce7c

xpu mul unittest *test=kunlun (#41140)
1 parent 1ed1a97 commit 770ce7c

3 files changed: +143 −113 lines changed

paddle/fluid/operators/mul_op_xpu.cc

Lines changed: 36 additions & 26 deletions
@@ -19,6 +19,8 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/xpu_api_wrapper.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 
 namespace paddle {
 namespace operators {
@@ -28,6 +30,8 @@ using framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class MulXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* x = context.Input<Tensor>("X");
@@ -62,14 +66,15 @@ class MulXPUKernel : public framework::OpKernel<T> {
     const T* data_b = y_matrix.data<T>();
     T* data_c = z->data<T>();
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int ret = xpu::fc_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k,
-                            alpha, data_a, data_b, beta, data_c);
-    PADDLE_ENFORCE_EQ(
-        ret, XPU_SUCCESS,
-        platform::errors::External(
-            "XPU API return wrong value[%d], please check whether "
-            "Baidu Kunlun Card is properly installed.",
-            ret));
+
+    int ret = xpu_fc_wrapper<XPUType, int16_t>(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
+        reinterpret_cast<const XPUType*>(data_b),
+        reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b, nullptr,
+        nullptr, nullptr, k, n, n, alpha, beta, nullptr,
+        xpu::Activation_t::LINEAR);
+    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
+
     if (z_dim.size() != 2) {
       z->Resize(z_dim);
     }
@@ -78,6 +83,8 @@ class MulXPUKernel : public framework::OpKernel<T> {
 
 template <typename DeviceContext, typename T>
 class MulGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
@@ -126,14 +133,14 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
       const T* data_a = dout->data<T>();
      const T* data_b = y_matrix.data<T>();
       T* data_c = dx_matrix.data<T>();
-      int ret =
-          xpu::gemm_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k, alpha,
-                          data_a, lda, data_b, ldb, beta, data_c, ldc);
-      PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU API return wrong value[%d], please check "
-                            "where Baidu Kunlun Card is properly installed.",
-                            ret));
+
+      int ret = xpu_fc_wrapper<XPUType, int16_t>(
+          dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
+          reinterpret_cast<const XPUType*>(data_b),
+          reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b,
+          nullptr, nullptr, nullptr, lda, ldb, ldc, alpha, beta, nullptr,
+          xpu::Activation_t::LINEAR);
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
     }
 
     if (dy) {
@@ -159,14 +166,14 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
       const T* data_a = x_matrix.data<T>();
       const T* data_b = dout->data<T>();
       T* data_c = dy_matrix.data<T>();
-      int ret =
-          xpu::gemm_int16(dev_ctx.x_context(), trans_a, trans_b, m, n, k, alpha,
-                          data_a, lda, data_b, ldb, beta, data_c, ldc);
-      PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU API return wrong value[%d], please check "
-                            "where Baidu Kunlun Card is properly installed.",
-                            ret));
+
+      int ret = xpu_fc_wrapper<XPUType, int16_t>(
+          dev_ctx.x_context(), reinterpret_cast<const XPUType*>(data_a),
+          reinterpret_cast<const XPUType*>(data_b),
+          reinterpret_cast<XPUType*>(data_c), m, n, k, trans_a, trans_b,
+          nullptr, nullptr, nullptr, lda, ldb, ldc, alpha, beta, nullptr,
+          xpu::Activation_t::LINEAR);
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu_fc_wrapper");
     }
   }
 };
@@ -175,9 +182,12 @@ class MulGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_XPU_KERNEL(
-    mul, ops::MulXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    mul, ops::MulXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MulXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
 REGISTER_OP_XPU_KERNEL(
-    mul_grad, ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, float>)
+    mul_grad, ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MulGradXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>)
 #endif
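
Note on the change above: both kernels now funnel through xpu_fc_wrapper, but the math is still the flattened matrix multiply that the mul op defines, and the unit test below builds its expected output the same way. A minimal NumPy sketch of that contract (the mul_reference helper is hypothetical, for illustration only):

import numpy as np

def mul_reference(x, y, x_num_col_dims=1, y_num_col_dims=1):
    # Collapse X to a 2-D matrix at x_num_col_dims, and Y at y_num_col_dims.
    x2d = x.reshape(int(np.prod(x.shape[:x_num_col_dims])), -1)
    y2d = y.reshape(int(np.prod(y.shape[:y_num_col_dims])), -1)
    out = np.dot(x2d, y2d)
    # Restore the leading dims of X and the trailing dims of Y.
    return out.reshape(*x.shape[:x_num_col_dims], *y.shape[y_num_col_dims:])

x = np.random.random((3, 4, 2, 9)).astype(np.float32)
y = np.random.random((3, 6, 1, 2, 3)).astype(np.float32)
assert mul_reference(x, y, 2, 2).shape == (3, 4, 1, 2, 3)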

paddle/fluid/platform/device/xpu/xpu2_op_list.h

Lines changed: 6 additions & 2 deletions
@@ -70,8 +70,10 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
      {"dropout_grad",
-      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-     {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                    pOpKernelType(vartype::FP16, XPUPlace())})},
+     {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                               pOpKernelType(vartype::FP16, XPUPlace())})},
      {"elementwise_add_grad",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
@@ -249,6 +251,8 @@ XPUOpMap& get_kl2_ops() {
      {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                            pOpKernelType(vartype::FP16, XPUPlace())})},
+     {"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                pOpKernelType(vartype::FP16, XPUPlace())})},
      {"nearest_interp_v2",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"nearest_interp_v2_grad",

python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py

Lines changed: 101 additions & 85 deletions
@@ -27,104 +27,120 @@
 
 paddle.enable_static()
 
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
 class TestMulOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
             # The input type of mul_op must be Variable.
             x1 = fluid.create_lod_tensor(
-                np.array([[-1]]), [[1]], fluid.CPUPlace())
+                np.array([[-1]]), [[1]], fluid.XPUPlace(0))
             x2 = fluid.create_lod_tensor(
-                np.array([[-1]]), [[1]], fluid.CPUPlace())
+                np.array([[-1]]), [[1]], fluid.XPUPlace(0))
             self.assertRaises(TypeError, fluid.layers.mul, x1, x2)
-            # The input dtype of mul_op must be float32 or float64.
+            # The input dtype of mul_op must be float32.
             x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
             x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
             self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
 
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestXPUMulOp1(XPUOpTest):
-    def setUp(self):
-        self.op_type = "mul"
-        self.dtype = np.float32
-        self.use_xpu = True
-        self.init_dtype_type()
-        self.inputs = {
-            'X': np.random.random((3, 4, 2, 9)).astype(self.dtype),
-            'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.dtype)
-        }
-        self.attrs = {
-            'x_num_col_dims': 2,
-            'y_num_col_dims': 2,
-        }
-        result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
-                        self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
-        result = result.reshape(3, 4, 1, 2, 3)
-        self.outputs = {'Out': result}
-
-    def init_dtype_type(self):
-        pass
-
-    def test_check_output(self):
-        place = paddle.XPUPlace(0)
-        self.check_output_with_place(place, atol=0.01)
-
-    def test_check_grad_normal(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
-
-    def test_check_grad_ingore_x(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
-
-    def test_check_grad_ignore_y(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
-
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestXPUMulOp2(XPUOpTest):
-    def setUp(self):
-        self.op_type = "mul"
-        self.use_xpu = True
-        self.dtype = np.float32
-        self.init_dtype_type()
-        self.inputs = {
-            'X': np.random.random((20, 5)).astype(self.dtype),
-            'Y': np.random.random((5, 21)).astype(self.dtype)
-        }
-        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
-
-    def init_dtype_type(self):
-        self.dtype = np.float32
-
-    def test_check_output(self):
-        place = paddle.XPUPlace(0)
-        self.check_output_with_place(place, atol=0.01)
-
-    def test_check_grad_normal(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
-
-    def test_check_grad_ingore_x(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
-
-    def test_check_grad_ingore_y(self):
-        place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
-
+class XPUTestMulOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'mul'
+        self.use_dynamic_create_class = False
+
+    class TestXPUMulOp1(XPUOpTest):
+        def setUp(self):
+            self.op_type = "mul"
+            self.dtype = self.in_type
+            self.inputs = {
+                'X': np.random.random((3, 4, 2, 9)).astype(self.in_type_str),
+                'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.in_type_str)
+            }
+            self.attrs = {
+                'x_num_col_dims': 2,
+                'y_num_col_dims': 2,
+            }
+            result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
+                            self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
+            result = result.reshape(3, 4, 1, 2, 3)
+            self.outputs = {'Out': result}
+
+        def test_check_output(self):
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place, atol=0.01)
+
+        def test_check_grad_normal(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.1)
+
+        def test_check_grad_ingore_x(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['Y'],
+                'Out',
+                max_relative_error=0.1,
+                no_grad_set=set("X"))
+
+        def test_check_grad_ignore_y(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['X'],
+                'Out',
+                max_relative_error=0.1,
+                no_grad_set=set('Y'))
+
+    class TestXPUMulOp2(XPUOpTest):
+        def setUp(self):
+            self.op_type = "mul"
+            self.use_xpu = True
+            self.dtype = self.in_type
+            self.inputs = {
+                'X': np.random.random((20, 5)).astype(self.in_type_str),
+                'Y': np.random.random((5, 21)).astype(self.in_type_str)
+            }
+            self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
+
+        def test_check_output(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_output_with_place(place, atol=0.01)
+
+        def test_check_grad_normal(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.1)
+
+        def test_check_grad_ingore_x(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['Y'],
+                'Out',
+                max_relative_error=0.1,
+                no_grad_set=set("X"))
+
+        def test_check_grad_ingore_y(self):
+            place = paddle.XPUPlace(0)
+            paddle.enable_static()
+            self.check_grad_with_place(
+                place, ['X'],
+                'Out',
+                max_relative_error=0.1,
+                no_grad_set=set('Y'))
+
+
+support_types = get_xpu_op_support_types('mul')
+for stype in support_types:
+    create_test_class(globals(), XPUTestMulOp, stype)
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
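
For readers unfamiliar with the wrapper pattern above: create_test_class takes the XPUOpTestWrapper subclass and one dtype string from get_xpu_op_support_types, and registers dtype-specialized copies of the inner XPUOpTest classes in the module namespace, which is why setUp() reads self.in_type and self.in_type_str. A rough sketch of that behavior (an assumption for illustration, not the helper's real code):

import unittest
import numpy as np

def create_test_class_sketch(scope, wrapper_cls, dtype_str):
    # Assumed behavior: clone each inner TestCase of the wrapper, injecting
    # in_type / in_type_str so setUp() builds inputs in the right dtype.
    dtype_map = {'float32': np.float32, 'float16': np.float16}
    for name, attr in list(vars(wrapper_cls).items()):
        if isinstance(attr, type) and issubclass(attr, unittest.TestCase):
            new_name = '%s_%s' % (name, dtype_str)
            scope[new_name] = type(new_name, (attr,), {
                'in_type': dtype_map[dtype_str],
                'in_type_str': dtype_str,
            })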
