Skip to content

Commit a7a6887

Browse files
committed
support argmax bf16
1 parent 8a62e37 commit a7a6887

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
3131
set(XPU_XRE_BASE_VERSION "4.32.0.1")
3232
endif()
3333
if(NOT DEFINED XPU_XHPC_BASE_DATE)
34-
set(XPU_XHPC_BASE_DATE "eb35/20241024")
34+
set(XPU_XHPC_BASE_DATE "eb35/20241104")
3535
endif()
3636
set(XPU_XCCL_BASE_VERSION "1.2.11e")
3737
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ XPUOpMap& get_kl3_ops() {
5858
XPUKernelSet({phi::DataType::INT32,
5959
phi::DataType::INT64,
6060
phi::DataType::FLOAT32,
61-
phi::DataType::FLOAT16})},
61+
phi::DataType::FLOAT16,
62+
phi::DataType::BFLOAT16})},
6263
{"arg_min",
6364
XPUKernelSet({phi::DataType::FLOAT32,
6465
phi::DataType::FLOAT16,

paddle/phi/kernels/xpu/arg_min_max_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ PD_REGISTER_KERNEL(argmax,
193193
float,
194194
int,
195195
int64_t,
196-
phi::dtype::float16) {
196+
phi::dtype::float16,
197+
phi::dtype::bfloat16) {
197198
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
198199
}
199200

test/xpu/test_arg_max_op_xpu.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
create_test_class,
2121
get_xpu_op_support_types,
2222
)
23+
from op_test import convert_float_to_uint16
2324
from op_test_xpu import XPUOpTest
2425

2526
import paddle
@@ -41,8 +42,17 @@ def setUp(self):
4142
self.dtype = self.in_type
4243
self.initTestCase()
4344

44-
self.x = (np.random.random(self.dims)).astype(self.dtype)
45-
self.inputs = {'X': self.x}
45+
self.x = (np.random.random(self.dims)).astype(
46+
self.dtype if self.dtype != np.uint16 else np.float32
47+
)
48+
49+
self.inputs = {
50+
'X': (
51+
self.x
52+
if self.dtype != np.uint16
53+
else convert_float_to_uint16(self.x)
54+
)
55+
}
4656
self.attrs = {'axis': self.axis, 'use_xpu': True}
4757
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
4858

0 commit comments

Comments
 (0)