Skip to content
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions paddle/fluid/operators/gather_scatter_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,17 @@ class ReduceAdd {
*self_data += *src_data;
}
};

static ReduceAdd reduce_add;

class ReduceMultiply {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
}
};
static ReduceMultiply reduce_mul;

template <typename tensor_t, typename index_t = int64_t,
bool is_scatter_like = true>
struct cpu_gather_scatter_functor {
Copy link

@iclementine iclementine Dec 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to split it into two classes; one for gather and the other for scatter. Reasons:

  1. don's pass arguments that is not used, for gather, reduction can be eliminated;
  2. better semantics, parameters that does not need to be mutated shoule be made a const parameter.(So we can pass const Tensor to that parameter).
  3. readability.

Copy link
Contributor Author

@huangxu96 huangxu96 Dec 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think split it is better since I think scatter and gather share a lot of logic. Making one kernel function can make the code more elegant. Don't want to make some redundant code.

Expand Down Expand Up @@ -75,7 +83,6 @@ struct cpu_gather_scatter_functor {
for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t index_idx = 0;
int64_t self_idx, src_idx;

Expand Down Expand Up @@ -141,8 +148,55 @@ void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_cpu", reduce_add, ctx);
}

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_cpu", reduce_mul, ctx);
}

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor output,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>();

auto index_dims = index.dims();
auto output_dims = output.dims();

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int output_select_dim_size = output_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t index_idx = 0;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * output_select_dim_size;
output_data[replace_index] = 0;
index_idx++;
}
}
}
}

Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_assign_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)

} // namespace operators
} // namespace paddle
85 changes: 83 additions & 2 deletions paddle/fluid/operators/gather_scatter_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ class ReduceAdd {
};
static ReduceAdd reduce_add;

class ReduceMul {
public:
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
// TODO(huangxu96) platform::CudaAtomicMul(*self_data, *src_data);
}
};
static ReduceMul reduce_mul;

template <typename tensor_t, typename index_t, typename func_t,
bool is_scatter_like = true>
__global__ void GatherScatterGPUKernel(
Expand Down Expand Up @@ -141,6 +151,14 @@ void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
return;
}

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
Expand All @@ -149,9 +167,72 @@ void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_gpu", reduce_add, ctx);
}

namespace plat = paddle::platform;
template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx);
}

template <typename tensor_t, typename index_t>
__global__ void ScatterInputGradGPUKernel(
tensor_t* grad_data, int dim, const index_t* index_data,
int64_t inner_dim_size, int select_dim_size, int grad_select_dim_size,
int64_t outer_dim_size, int64_t numel) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
j = remind / outer_dim_size;
k = remind % outer_dim_size;
index_t index = index_data[tid];
int64_t replace_index =
k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size;
grad_data[replace_index] = 0;
}
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor grad,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* grad_data = grad.data<tensor_t>();

auto index_dims = index.dims();
auto grad_dims = grad.dims();
int64_t index_size = index.numel();

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int grad_select_dim_size = grad_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t slice_size = 1;
for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];

int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();

ScatterInputGradGPUKernel<tensor_t, index_t><<<grid, block, 0, stream>>>(
grad_data, dim, index_data, inner_dim_size, select_dim_size,
grad_select_dim_size, outer_dim_size, index_size);
}
Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_assign_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel)

} // namespace operators
} // namespace paddle
25 changes: 25 additions & 0 deletions paddle/fluid/operators/gather_scatter_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,42 @@ template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion:

  1. const Tensor& self as const input parameter;
  2. Tensor* as output parameter;
  3. device context as the first parameter(a convention in paddle followed by many functions that take tensor as parameter in paddle).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I will modify the self tensor in my code, I declare using a const variable is not necessary here.

const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);
} // namespace operators
} // namespace paddle
149 changes: 149 additions & 0 deletions paddle/fluid/operators/put_along_axis_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/put_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

class PutAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
Copy link
Contributor

@chenwhql chenwhql Dec 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议使用OP_INOUT_CHECK

ctx->HasInput("Input"), true,
platform::errors::InvalidArgument(
"Input(Input) of PutAlongAxisOpOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of PutAlongAxisOpOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Value"), true,
platform::errors::InvalidArgument(
"Input(Value) of PutAlongAxisOpOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Result"), true,
platform::errors::InvalidArgument(
"Output(Result) of PutAlongAxisOpOp should not be null."));

auto index_dim = ctx->GetInputDim("Index");

ctx->SetOutputDim("Result", index_dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor of PutAlongAxisOp");
AddInput("Index", "The index tensor of PutAlongAxisOp");
AddInput("Value", "The value tensor of PutAlongAxisOp");
AddOutput("Result", "The result tensor of PutAlongAxisOp");
AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation");
AddAttr<std::string>("Reduce", "The reduce operation for scatter")
.SetDefault("assign");
AddComment(R"DOC(
PutAlongAxis Operator.)
)DOC");
}
};

class PutAlongAxisGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Result")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

template <typename T>
class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("put_along_axis_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Input", this->Input("Input"));

op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value"));
op->SetAttrMap(this->Attrs());
}
};

DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::PutAlongAxisInplaceInferer);

REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);

REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
ops::PutAlongAxisOpKernel<double>,
ops::PutAlongAxisOpKernel<int>,
ops::PutAlongAxisOpKernel<uint8_t>,
ops::PutAlongAxisOpKernel<int64_t>);

REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpKernel<float>,
ops::PutAlongAxisGradOpKernel<double>,
ops::PutAlongAxisGradOpKernel<int>,
ops::PutAlongAxisGradOpKernel<uint8_t>,
ops::PutAlongAxisGradOpKernel<int64_t>);
Loading