Skip to content

Commit bda614d

Browse files
authored
[XPU] bfloat16 support for gather/gather_grad/scatter/scatter_grad (PaddlePaddle#71132)
1 parent 4f1e774 commit bda614d

File tree

11 files changed

+70
-15
lines changed

11 files changed

+70
-15
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
3030
set(XPU_XPUDNN_LIB_NAME "libxpu_dnn.so")
3131

3232
if(NOT DEFINED XPU_XHPC_BASE_DATE)
33-
set(XPU_XHPC_BASE_DATE "dev/20250204")
33+
set(XPU_XHPC_BASE_DATE "dev/20250213")
3434
endif()
3535
set(XPU_XCCL_BASE_VERSION "3.0.2.3") # For XRE5
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,14 @@ XPUOpMap& get_kl3_ops() {
701701
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
702702
{"floor", XPUKernelSet({phi::DataType::FLOAT32})},
703703
{"gather_grad",
704-
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
704+
XPUKernelSet({phi::DataType::FLOAT32,
705+
phi::DataType::FLOAT16,
706+
phi::DataType::BFLOAT16,
707+
phi::DataType::INT8,
708+
phi::DataType::INT16,
709+
phi::DataType::INT32,
710+
phi::DataType::INT64,
711+
phi::DataType::BOOL})},
705712
{"gather_nd_grad",
706713
XPUKernelSet({phi::DataType::INT32,
707714
phi::DataType::INT64,
@@ -717,6 +724,9 @@ XPUOpMap& get_kl3_ops() {
717724
{"gather",
718725
XPUKernelSet({phi::DataType::FLOAT32,
719726
phi::DataType::FLOAT16,
727+
phi::DataType::BFLOAT16,
728+
phi::DataType::INT8,
729+
phi::DataType::INT16,
720730
phi::DataType::INT32,
721731
phi::DataType::INT64,
722732
phi::DataType::BOOL})},
@@ -1154,9 +1164,13 @@ XPUOpMap& get_kl3_ops() {
11541164
{"scatter",
11551165
XPUKernelSet({phi::DataType::INT64,
11561166
phi::DataType::INT32,
1167+
phi::DataType::FLOAT16,
1168+
phi::DataType::BFLOAT16,
11571169
phi::DataType::FLOAT32})},
11581170
{"scatter_grad",
1159-
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
1171+
XPUKernelSet({phi::DataType::FLOAT16,
1172+
phi::DataType::BFLOAT16,
1173+
phi::DataType::FLOAT32})},
11601174
{"scatter_nd_add",
11611175
XPUKernelSet({phi::DataType::FLOAT32,
11621176
phi::DataType::INT32,

paddle/phi/kernels/cpu/gather_grad_kernel.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ PD_REGISTER_KERNEL(gather_grad,
7272
phi::GatherGradKernel,
7373
float,
7474
double,
75-
int,
7675
uint8_t,
76+
int8_t,
77+
int16_t,
78+
int32_t,
7779
int64_t,
80+
bool,
7881
phi::dtype::bfloat16,
7982
phi::dtype::complex<float>,
8083
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/gather_kernel.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,12 @@ PD_REGISTER_KERNEL(gather,
7070
phi::GatherKernel,
7171
float,
7272
double,
73-
int,
7473
uint8_t,
74+
int8_t,
75+
int16_t,
76+
int32_t,
7577
int64_t,
78+
bool,
7679
phi::dtype::bfloat16,
7780
phi::dtype::complex<float>,
7881
phi::dtype::complex<double>) {}

paddle/phi/kernels/fusion/xpu/fused_rope_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,11 +358,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
358358
nullptr,
359359
reinterpret_cast<const XPUSCType*>(sin_data),
360360
reinterpret_cast<const XPUSCType*>(cos_data),
361+
nullptr,
361362
reinterpret_cast<XPUType*>(out_q->data()),
362363
nullptr,
363364
{batch_size, seq_len, num_heads, head_dim},
364365
{batch_size, seq_len, 1, head_dim},
365366
{},
367+
0,
366368
"BLHD",
367369
-1,
368370
10000.0f);
@@ -374,11 +376,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
374376
reinterpret_cast<const XPUType*>(in_k->data()),
375377
reinterpret_cast<const XPUSCType*>(sin_data),
376378
reinterpret_cast<const XPUSCType*>(cos_data),
379+
nullptr,
377380
reinterpret_cast<XPUType*>(out_q->data()),
378381
reinterpret_cast<XPUType*>(out_k->data()),
379382
{batch_size, seq_len, num_heads, head_dim},
380383
{batch_size, seq_len, 1, head_dim},
381384
{},
385+
0,
382386
"BLHD",
383387
num_heads_k,
384388
10000.0f);
@@ -392,11 +396,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
392396
nullptr,
393397
reinterpret_cast<const XPUSCType*>(sin_data),
394398
reinterpret_cast<const XPUSCType*>(cos_data),
399+
nullptr,
395400
reinterpret_cast<XPUType*>(out_v->data()),
396401
nullptr,
397402
{batch_size, seq_len, num_heads_v, head_dim},
398403
{batch_size, seq_len, 1, head_dim},
399404
{},
405+
0,
400406
"BLHD",
401407
-1,
402408
10000.0f);

paddle/phi/kernels/xpu/gather_grad_kernel.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,10 @@ PD_REGISTER_KERNEL(gather_grad,
101101
ALL_LAYOUT,
102102
phi::GatherGradKernel,
103103
float,
104-
phi::dtype::float16) {}
104+
phi::dtype::float16,
105+
phi::dtype::bfloat16,
106+
int8_t,
107+
int16_t,
108+
int32_t,
109+
int64_t,
110+
bool) {}

paddle/phi/kernels/xpu/gather_kernel.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ PD_REGISTER_KERNEL(gather,
8888
phi::GatherKernel,
8989
float,
9090
phi::dtype::float16,
91-
int,
91+
phi::dtype::bfloat16,
92+
int8_t,
93+
int16_t,
94+
int32_t,
9295
int64_t,
9396
bool) {}

paddle/phi/kernels/xpu/scatter_grad_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(scatter_grad,
9393
ALL_LAYOUT,
9494
phi::ScatterGradKernel,
9595
float,
96-
phi::dtype::float16) {}
96+
phi::dtype::float16,
97+
phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/scatter_kernel.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ PD_REGISTER_KERNEL(scatter,
119119
ALL_LAYOUT,
120120
phi::ScatterKernel,
121121
float,
122-
int,
122+
int32_t,
123123
int64_t,
124-
phi::dtype::float16) {}
124+
phi::dtype::float16,
125+
phi::dtype::bfloat16) {}

test/xpu/test_gather_op_xpu.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
create_test_class,
2121
get_xpu_op_support_types,
2222
)
23+
from op_test import convert_float_to_uint16
2324
from op_test_xpu import XPUOpTest
2425

2526
import paddle
@@ -45,7 +46,11 @@ def setUp(self):
4546
self.dtype = self.in_type
4647

4748
self.init_config()
48-
xnp = np.random.random(self.x_shape).astype(self.dtype)
49+
xnp = np.random.random(self.x_shape).astype(
50+
self.dtype if self.dtype != np.uint16 else np.float32
51+
)
52+
if self.dtype == np.uint16:
53+
xnp = convert_float_to_uint16(xnp)
4954
self.inputs = {
5055
'X': xnp,
5156
'Index': np.array(self.index).astype(self.index_type),

0 commit comments

Comments
 (0)