Skip to content

Commit 8fc685d

Browse files
authored
[XPU] support index_add and index_add_grad in xpu, support bf16 and fp16 for index_select (#67955)
* support index_add and index_add_grad in xpu, support bf16 and fp16 for index_select
* fix the code style
* fix test_index_add_op_xpu.py code style
* pre-commit to fix the code style
* fix test_index_add_op_xpu
* fix the test: test_index_add_op_xpu.py
* fix the test, skip xpu2
1 parent 3a529b8 commit 8fc685d

File tree

5 files changed

+521
-15
lines changed

5 files changed

+521
-15
lines changed

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,21 @@ XPUOpMap& get_kl3_ops() {
612612
{"index_select",
613613
XPUKernelSet({phi::DataType::FLOAT32,
614614
phi::DataType::INT32,
615+
phi::DataType::FLOAT16,
616+
phi::DataType::BFLOAT16,
615617
phi::DataType::INT64})},
618+
{"index_add",
619+
XPUKernelSet({phi::DataType::FLOAT32,
620+
phi::DataType::INT32,
621+
phi::DataType::INT64,
622+
phi::DataType::FLOAT16,
623+
phi::DataType::BFLOAT16})},
624+
{"index_add_grad",
625+
XPUKernelSet({phi::DataType::FLOAT32,
626+
phi::DataType::INT32,
627+
phi::DataType::INT64,
628+
phi::DataType::FLOAT16,
629+
phi::DataType::BFLOAT16})},
616630
{"instance_norm",
617631
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
618632
{"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/index_add_grad_kernel.h"
16+
17+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/core/utils/data_type.h"
20+
#include "paddle/phi/kernels/index_select_kernel.h"
21+
22+
namespace phi {
23+
24+
template <typename T, typename Context>
25+
void IndexAddGradKernel(const Context& ctx,
26+
const DenseTensor& index,
27+
const DenseTensor& add_value,
28+
const DenseTensor& out_grad,
29+
int dim,
30+
DenseTensor* x_grad,
31+
DenseTensor* add_value_grad) {
32+
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
33+
phi::IndexSelectKernel<T, Context>(ctx, out_grad, index, dim, add_value_grad);
34+
}
35+
36+
} // namespace phi
37+
38+
PD_REGISTER_KERNEL(index_add_grad,
39+
XPU,
40+
ALL_LAYOUT,
41+
phi::IndexAddGradKernel,
42+
float,
43+
phi::dtype::float16,
44+
phi::dtype::bfloat16,
45+
int,
46+
int64_t) {}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/index_add_kernel.h"
16+
17+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
22+
template <typename T, typename Context>
23+
void IndexAddKernel(const Context& ctx,
24+
const DenseTensor& x,
25+
const DenseTensor& index,
26+
const DenseTensor& add_value,
27+
int axis,
28+
DenseTensor* out) {
29+
auto index_type = index.dtype();
30+
bool index_type_match =
31+
index_type == DataType::INT32 || index_type == DataType::INT64;
32+
PADDLE_ENFORCE_EQ(index_type_match,
33+
true,
34+
errors::InvalidArgument(
35+
"Input(Index) holds the wrong type, it holds %s, but "
36+
"desires to be %s or %s",
37+
DataTypeToString(index_type),
38+
DataTypeToString(DataType::INT32),
39+
DataTypeToString(DataType::INT64)));
40+
41+
using XPUType = typename XPUTypeTrait<T>::Type;
42+
auto input_dim = x.dims();
43+
int dim = axis >= 0 ? axis : axis + input_dim.size();
44+
auto input_vector = common::vectorize<int64_t>(input_dim);
45+
int64_t numel = add_value.numel();
46+
if (numel == 0) return;
47+
ctx.template Alloc<T>(out);
48+
int r = 0;
49+
if (index_type == phi::DataType::INT64) {
50+
r = xpu::index_add<XPUType, int64_t>(
51+
ctx.x_context(),
52+
reinterpret_cast<const XPUType*>(x.data<T>()),
53+
reinterpret_cast<const XPUType*>(add_value.data<T>()),
54+
reinterpret_cast<XPUType*>(out->data<T>()),
55+
reinterpret_cast<const int64_t*>(index.data<int64_t>()),
56+
input_vector,
57+
index.numel(),
58+
dim,
59+
(XPUType)(1.0f));
60+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_add");
61+
} else if (index_type == phi::DataType::INT32) {
62+
r = xpu::index_add<XPUType, int>(
63+
ctx.x_context(),
64+
reinterpret_cast<const XPUType*>(x.data<T>()),
65+
reinterpret_cast<const XPUType*>(add_value.data<T>()),
66+
reinterpret_cast<XPUType*>(out->data<T>()),
67+
reinterpret_cast<const int*>(index.data<int>()),
68+
input_vector,
69+
index.numel(),
70+
dim,
71+
(XPUType)(1.0f));
72+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_add");
73+
}
74+
}
75+
76+
} // namespace phi
77+
78+
PD_REGISTER_KERNEL(index_add,
79+
XPU,
80+
ALL_LAYOUT,
81+
phi::IndexAddKernel,
82+
phi::dtype::float16,
83+
phi::dtype::bfloat16,
84+
float,
85+
int64_t,
86+
int32_t) {}

paddle/phi/kernels/xpu/index_select_kernel.cc

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/index_select_kernel.h"
16+
1617
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1718
#include "paddle/phi/common/memory_utils.h"
1819
#include "paddle/phi/core/kernel_registry.h"
@@ -40,10 +41,11 @@ void IndexSelectKernel(const Context& ctx,
4041
index_type,
4142
phi::DataType::INT32,
4243
phi::DataType::INT64));
44+
using XPUType = typename XPUTypeTrait<T>::Type;
4345
auto* in_data = x.data<T>();
4446
std::vector<int> in_shape = common::vectorize<int>(input_dim);
4547
int index_len = output->dims()[dim];
46-
T* out_data = ctx.template Alloc<T>(output);
48+
ctx.template Alloc<T>(output);
4749
int r = 0;
4850
xpu::ctx_guard RAII_GUARD(ctx.x_context());
4951
int8_t* index_ptr = nullptr; // temp xpu buffer
@@ -67,23 +69,24 @@ void IndexSelectKernel(const Context& ctx,
6769
const int64_t* index_data =
6870
index_ptr ? reinterpret_cast<const int64_t*>(index_ptr)
6971
: index.template data<int64_t>();
70-
r = xpu::gather<T, int64_t>(ctx.x_context(),
71-
in_data,
72-
index_data,
73-
out_data,
74-
in_shape,
75-
index_len,
76-
dim);
72+
r = xpu::gather<XPUType, int64_t>(
73+
ctx.x_context(),
74+
reinterpret_cast<const XPUType*>(in_data),
75+
reinterpret_cast<const int64_t*>(index_data),
76+
reinterpret_cast<XPUType*>(output->data<T>()),
77+
in_shape,
78+
index_len,
79+
dim);
7780
} else {
7881
const int* index_data = index_ptr ? reinterpret_cast<const int*>(index_ptr)
7982
: index.template data<int>();
80-
r = xpu::gather<T, int>(ctx.x_context(),
81-
in_data,
82-
index_data,
83-
out_data,
84-
in_shape,
85-
index_len,
86-
dim);
83+
r = xpu::gather<XPUType, int>(ctx.x_context(),
84+
reinterpret_cast<const XPUType*>(in_data),
85+
reinterpret_cast<const int*>(index_data),
86+
reinterpret_cast<XPUType*>(output->data<T>()),
87+
in_shape,
88+
index_len,
89+
dim);
8790
}
8891
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
8992
}
@@ -95,5 +98,7 @@ PD_REGISTER_KERNEL(index_select,
9598
ALL_LAYOUT,
9699
phi::IndexSelectKernel,
97100
float,
101+
phi::dtype::float16,
102+
phi::dtype::bfloat16,
98103
int,
99104
int64_t) {}

0 commit comments

Comments
 (0)