Skip to content
14 changes: 14 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,21 @@ XPUOpMap& get_kl3_ops() {
{"index_select",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64})},
{"index_add",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"index_add_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"instance_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down
46 changes: 46 additions & 0 deletions paddle/phi/kernels/xpu/index_add_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_add_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/index_select_kernel.h"

namespace phi {

// Gradient of index_add along `dim`:
//   x_grad         = out_grad (index_add passes every element of x straight
//                    through, so the gradient w.r.t. x is the identity).
//   add_value_grad = out_grad gathered at `index` along `dim` (each slice of
//                    add_value was added into exactly one indexed slice of
//                    out, so its gradient is an index_select of out_grad).
//
// Both gradient outputs are optional; only the tensors the caller requested
// (non-null pointers) are produced. Dereferencing them unconditionally would
// crash when a stop-gradient input skips one of them.
template <typename T, typename Context>
void IndexAddGradKernel(const Context& ctx,
                        const DenseTensor& index,
                        const DenseTensor& add_value,
                        const DenseTensor& out_grad,
                        int dim,
                        DenseTensor* x_grad,
                        DenseTensor* add_value_grad) {
  if (x_grad) {
    phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
  }
  if (add_value_grad) {
    phi::IndexSelectKernel<T, Context>(
        ctx, out_grad, index, dim, add_value_grad);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(index_add_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::IndexAddGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int,
                   int64_t) {}
86 changes: 86 additions & 0 deletions paddle/phi/kernels/xpu/index_add_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_add_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {

template <typename T, typename Context>
void IndexAddKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& add_value,
int axis,
DenseTensor* out) {
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
DataTypeToString(index_type),
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64)));

using XPUType = typename XPUTypeTrait<T>::Type;
auto input_dim = x.dims();
int dim = axis >= 0 ? axis : axis + input_dim.size();
auto input_vector = common::vectorize<int64_t>(input_dim);
int64_t numel = add_value.numel();
if (numel == 0) return;
ctx.template Alloc<T>(out);
int r = 0;
if (index_type == phi::DataType::INT64) {
r = xpu::index_add<XPUType, int64_t>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(add_value.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
reinterpret_cast<const int64_t*>(index.data<int64_t>()),
input_vector,
index.numel(),
dim,
(XPUType)(1.0f));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_add");
} else if (index_type == phi::DataType::INT32) {
r = xpu::index_add<XPUType, int>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(add_value.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
reinterpret_cast<const int*>(index.data<int>()),
input_vector,
index.numel(),
dim,
(XPUType)(1.0f));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_add");
}
}

} // namespace phi

PD_REGISTER_KERNEL(index_add,
XPU,
ALL_LAYOUT,
phi::IndexAddKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
int64_t,
int32_t) {}
35 changes: 20 additions & 15 deletions paddle/phi/kernels/xpu/index_select_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/phi/kernels/index_select_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down Expand Up @@ -40,10 +41,11 @@ void IndexSelectKernel(const Context& ctx,
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
using XPUType = typename XPUTypeTrait<T>::Type;
auto* in_data = x.data<T>();
std::vector<int> in_shape = common::vectorize<int>(input_dim);
int index_len = output->dims()[dim];
T* out_data = ctx.template Alloc<T>(output);
ctx.template Alloc<T>(output);
int r = 0;
xpu::ctx_guard RAII_GUARD(ctx.x_context());
int8_t* index_ptr = nullptr; // temp xpu buffer
Expand All @@ -67,23 +69,24 @@ void IndexSelectKernel(const Context& ctx,
const int64_t* index_data =
index_ptr ? reinterpret_cast<const int64_t*>(index_ptr)
: index.template data<int64_t>();
r = xpu::gather<T, int64_t>(ctx.x_context(),
in_data,
index_data,
out_data,
in_shape,
index_len,
dim);
r = xpu::gather<XPUType, int64_t>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(in_data),
reinterpret_cast<const int64_t*>(index_data),
reinterpret_cast<XPUType*>(output->data<T>()),
in_shape,
index_len,
dim);
} else {
const int* index_data = index_ptr ? reinterpret_cast<const int*>(index_ptr)
: index.template data<int>();
r = xpu::gather<T, int>(ctx.x_context(),
in_data,
index_data,
out_data,
in_shape,
index_len,
dim);
r = xpu::gather<XPUType, int>(ctx.x_context(),
reinterpret_cast<const XPUType*>(in_data),
reinterpret_cast<const int*>(index_data),
reinterpret_cast<XPUType*>(output->data<T>()),
in_shape,
index_len,
dim);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
}
Expand All @@ -95,5 +98,7 @@ PD_REGISTER_KERNEL(index_select,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
Loading