Skip to content

Commit 57d95d1

Browse files
authored
[XPU] bind Addmm (#68560)

* [XPU] bind Addmm
* fix
1 parent b7624cf commit 57d95d1

File tree

10 files changed

+637
-19
lines changed

10 files changed

+637
-19
lines changed

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ XPUOpMap& get_kl2_ops() {
4646
{"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
4747
{"addcmul_xpu",
4848
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
49+
{"addmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
50+
{"addmm_grad",
51+
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
4952
{"arange_tensor",
5053
XPUKernelSet({phi::DataType::FLOAT32,
5154
phi::DataType::INT32,

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ XPUOpMap& get_kl3_ops() {
4040
{"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
4141
{"addcmul_xpu",
4242
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
43+
{"addmm",
44+
XPUKernelSet({phi::DataType::FLOAT32,
45+
phi::DataType::FLOAT16,
46+
phi::DataType::BFLOAT16})},
47+
{"addmm_grad",
48+
XPUKernelSet({phi::DataType::FLOAT32,
49+
phi::DataType::FLOAT16,
50+
phi::DataType::BFLOAT16})},
4351
{"arange_tensor",
4452
XPUKernelSet({phi::DataType::FLOAT32,
4553
phi::DataType::INT32,

paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,10 @@ void FFNGrad(const phi::XPUContext& dev_ctx,
246246
}
247247

248248
phi::MatMulXPUFunction<XPUTypeT>(
249-
xpu_ctx, a_1, b_1, c_1, info_d_dropout1, 1.0f, true);
249+
xpu_ctx, a_1, b_1, c_1, info_d_dropout1, 1.0f, 0.f, true);
250250

251251
phi::MatMulXPUFunction<XPUTypeT>(
252-
xpu_ctx, a_2, b_2, c_2, info_dw2, 1.0f, true);
252+
xpu_ctx, a_2, b_2, c_2, info_dw2, 1.0f, 0.f, true);
253253

254254
// dropout_grad1
255255
DropoutGrad(xpu_ctx,
@@ -335,10 +335,11 @@ void FFNGrad(const phi::XPUContext& dev_ctx,
335335

336336
std::tie(info_dx, info_dw1, a_1, b_1, a_2, b_2) = fc_info;
337337

338-
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f, true);
338+
phi::MatMulXPUFunction<XPUTypeT>(
339+
xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f, 0.f, true);
339340

340341
phi::MatMulXPUFunction<XPUTypeT>(
341-
xpu_ctx, a_2, b_2, c_2, info_dw1, 1.0f, true);
342+
xpu_ctx, a_2, b_2, c_2, info_dw1, 1.0f, 0.f, true);
342343

343344
if (pre_layer_norm) {
344345
r = xpu::layer_norm_grad(xpu_ctx,

paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
7373
"FusedGemm do not support batched fc now, but got batch size %d.",
7474
batch_size));
7575
MatMulXPUFunction<XPUType>(
76-
xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f, false, act);
76+
xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f, 0.f, false, act);
7777
}
7878

7979
} // namespace fusion
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/addmm_grad_kernel.h"
16+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
17+
#include "paddle/phi/backends/xpu/xpu_context.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
20+
21+
namespace phi {
22+
23+
template <typename T, typename Context>
24+
void AddmmGradKernel(const Context& dev_ctx,
25+
const DenseTensor& input,
26+
const DenseTensor& x,
27+
const DenseTensor& y,
28+
const DenseTensor& out_grad,
29+
float alpha,
30+
float beta,
31+
DenseTensor* input_grad,
32+
DenseTensor* x_grad,
33+
DenseTensor* y_grad) {
34+
using XPUType = typename XPUTypeTrait<T>::Type;
35+
36+
xpu::Context* xpu_ctx = dev_ctx.x_context();
37+
xpu::ctx_guard RAII_GUARD(xpu_ctx);
38+
int r;
39+
40+
if (input_grad) {
41+
dev_ctx.template Alloc<T>(input_grad);
42+
XPUType* input_grad_ptr = reinterpret_cast<XPUType*>(input_grad->data<T>());
43+
r = xpu::constant(xpu_ctx, input_grad_ptr, input.numel(), (XPUType)(beta));
44+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
45+
if (input_grad->dims().size() == 1 && out_grad.dims()[0] > 1) {
46+
r = xpu::scale<XPUType>(xpu_ctx,
47+
input_grad_ptr,
48+
input_grad_ptr,
49+
input_grad->numel(),
50+
true,
51+
static_cast<float>(out_grad.dims()[0]),
52+
0.f);
53+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
54+
}
55+
}
56+
if (x_grad) {
57+
dev_ctx.template Alloc<T>(x_grad);
58+
}
59+
if (y_grad) {
60+
dev_ctx.template Alloc<T>(y_grad);
61+
}
62+
63+
const XPUType* out_grad_ptr =
64+
reinterpret_cast<const XPUType*>(out_grad.data<T>());
65+
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
66+
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
67+
68+
XpuFcInfo info_forward;
69+
GetFCInfo(x.dims(), y.dims(), false, false, &info_forward);
70+
// begin calculate
71+
const XPUType* a_1 = nullptr;
72+
const XPUType* b_1 = nullptr;
73+
const XPUType* a_2 = nullptr;
74+
const XPUType* b_2 = nullptr;
75+
XPUType* c_1 = reinterpret_cast<XPUType*>(x_grad->data<T>());
76+
XPUType* c_2 = reinterpret_cast<XPUType*>(y_grad->data<T>());
77+
78+
if (x_grad && info_forward.is_x_need_broadcast) {
79+
c_1 = RAII_GUARD.alloc_l3_or_gm<XPUType>(info_forward.bs * info_forward.m *
80+
info_forward.k);
81+
PADDLE_ENFORCE_XDNN_NOT_NULL(c_1);
82+
}
83+
84+
if (y_grad && info_forward.is_y_need_broadcast) {
85+
c_2 = RAII_GUARD.alloc_l3_or_gm<XPUType>(info_forward.bs * info_forward.k *
86+
info_forward.n);
87+
PADDLE_ENFORCE_XDNN_NOT_NULL(c_2);
88+
}
89+
90+
XpuFcInfo info_x_grad;
91+
XpuFcInfo info_y_grad;
92+
std::tuple<XpuFcInfo,
93+
XpuFcInfo,
94+
const XPUType*,
95+
const XPUType*,
96+
const XPUType*,
97+
const XPUType*>
98+
fc_info = MatmulGradFcInfo(xpu_ctx,
99+
&RAII_GUARD,
100+
info_forward,
101+
false,
102+
false,
103+
x_ptr,
104+
y_ptr,
105+
out_grad_ptr);
106+
std::tie(info_x_grad, info_y_grad, a_1, b_1, a_2, b_2) = fc_info;
107+
if (x_grad) {
108+
MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_x_grad, alpha, 0.f);
109+
if (info_forward.is_x_need_broadcast) {
110+
r = xpu::reduce_sum<XPUType>(
111+
xpu_ctx,
112+
c_1,
113+
reinterpret_cast<XPUType*>(x_grad->data<T>()),
114+
{info_forward.bs, info_forward.m, info_forward.k},
115+
{0});
116+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
117+
}
118+
}
119+
if (y_grad) {
120+
MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_y_grad, alpha, 0.f);
121+
if (info_forward.is_y_need_broadcast) {
122+
r = xpu::reduce_sum<XPUType>(
123+
xpu_ctx,
124+
c_2,
125+
reinterpret_cast<XPUType*>(y_grad->data<T>()),
126+
{info_forward.bs, info_forward.k, info_forward.n},
127+
{0});
128+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
129+
}
130+
}
131+
}
132+
} // namespace phi
133+
134+
PD_REGISTER_KERNEL(addmm_grad,
135+
XPU,
136+
ALL_LAYOUT,
137+
phi::AddmmGradKernel,
138+
float,
139+
phi::dtype::bfloat16,
140+
phi::dtype::float16) {}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/addmm_kernel.h"
16+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
17+
#include "paddle/phi/backends/xpu/xpu_context.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "xblas/cublasLt.h"
20+
21+
#ifndef PADDLE_WITH_XPU_XRE5
22+
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
23+
#endif
24+
25+
namespace xblas = baidu::xpu::xblas;
26+
27+
namespace phi {
28+
29+
template <typename T, typename Context>
30+
void AddmmKernel(const Context& dev_ctx,
31+
const DenseTensor& input,
32+
const DenseTensor& x,
33+
const DenseTensor& y,
34+
float beta,
35+
float alpha,
36+
DenseTensor* out) {
37+
using XPUType = typename XPUTypeTrait<T>::Type;
38+
39+
auto input_dims = input.dims();
40+
auto x_dims = x.dims();
41+
auto y_dims = y.dims();
42+
PADDLE_ENFORCE_EQ(
43+
input_dims.size() == 2 || input_dims.size() == 1,
44+
true,
45+
common::errors::InvalidArgument(
46+
"Variable 'input' of AddmmOp must be 1-dimensional or 2-dimensional, "
47+
"but received shape: [%s]",
48+
input_dims));
49+
PADDLE_ENFORCE_EQ(x_dims.size() == 2,
50+
true,
51+
common::errors::InvalidArgument(
52+
"Variable 'x' of AddmmOp must be 2-dimensional, "
53+
"but received shape: [%s]",
54+
input_dims));
55+
PADDLE_ENFORCE_EQ(y_dims.size() == 2,
56+
true,
57+
common::errors::InvalidArgument(
58+
"Variable 'y' of AddmmOp must be 2-dimensional, "
59+
"but received shape: [%s]",
60+
input_dims));
61+
62+
dev_ctx.template Alloc<T>(out);
63+
const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
64+
const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
65+
const XPUType* input_ptr = reinterpret_cast<const XPUType*>(input.data<T>());
66+
XPUType* out_ptr = reinterpret_cast<XPUType*>(out->data<T>());
67+
68+
int r;
69+
if (alpha == 0.f) {
70+
if (beta == 0.f) {
71+
r = xpu::constant(dev_ctx.x_context(),
72+
out_ptr,
73+
out->numel(),
74+
static_cast<XPUType>(0.0f));
75+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
76+
} else {
77+
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
78+
T* beta_xpu = RAII_GUARD.alloc_l3_or_gm<T>(1);
79+
r = xpu::constant(dev_ctx.x_context(),
80+
reinterpret_cast<XPUType*>(beta_xpu),
81+
out->numel(),
82+
static_cast<XPUType>(beta));
83+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
84+
auto input_dims_vec = common::vectorize<int64_t>(input.dims());
85+
auto out_dims_vec = common::vectorize<int64_t>(out->dims());
86+
r = xpu::broadcast_mul<XPUType>(dev_ctx.x_context(),
87+
input_ptr,
88+
reinterpret_cast<XPUType*>(beta_xpu),
89+
out_ptr,
90+
input_dims_vec,
91+
out_dims_vec);
92+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
93+
}
94+
#ifdef PADDLE_WITH_XPU_XRE5
95+
} else {
96+
xblas::FcFusionTensor<const XPUType> t_input{
97+
input_ptr,
98+
nullptr,
99+
input.dims()[0],
100+
input.dims()[1],
101+
input.dims()[1],
102+
false,
103+
};
104+
xblas::FcFusionTensor<const XPUType> t_x{
105+
x_ptr,
106+
nullptr,
107+
x.dims()[0],
108+
x.dims()[1],
109+
x.dims()[1],
110+
false,
111+
};
112+
xblas::FcFusionTensor<const XPUType> t_y{
113+
y_ptr,
114+
nullptr,
115+
y.dims()[0],
116+
y.dims()[1],
117+
y.dims()[1],
118+
false,
119+
};
120+
xblas::FcFusionTensor<XPUType> t_out{
121+
out_ptr,
122+
nullptr,
123+
out->dims()[0],
124+
out->dims()[1],
125+
out->dims()[1],
126+
false,
127+
};
128+
xblas::FcFusionDesc<float, float, XPUType> desc{
129+
alpha,
130+
beta,
131+
};
132+
xblas::FcFusionEpilogue<float, float> epilogue{
133+
xdnn::Activation_t::LINEAR,
134+
nullptr,
135+
nullptr,
136+
nullptr,
137+
0,
138+
0,
139+
nullptr,
140+
};
141+
r = xblas::fc_fusion<XPUType,
142+
XPUType,
143+
XPUType,
144+
XPUType,
145+
float,
146+
float,
147+
XPUType,
148+
float,
149+
float>(
150+
dev_ctx.x_context(), t_x, t_y, t_input, t_out, desc, epilogue);
151+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
152+
#else
153+
} else {
154+
Copy(dev_ctx, input, dev_ctx.GetPlace(), false, out);
155+
XpuFcInfo fc_info;
156+
GetFCInfo(x_dims, y_dims, false, false, &fc_info);
157+
MatMulXPUFunction<XPUType>(
158+
dev_ctx.x_context(), x_ptr, y_ptr, out_ptr, fc_info, alpha, beta);
159+
#endif
160+
}
161+
}
162+
} // namespace phi
163+
164+
PD_REGISTER_KERNEL(addmm,
165+
XPU,
166+
ALL_LAYOUT,
167+
phi::AddmmKernel,
168+
float,
169+
phi::dtype::bfloat16,
170+
phi::dtype::float16) {}

0 commit comments

Comments
 (0)