Merged

24 commits
f56556c
init commit
huangxu96 Dec 7, 2021
a292f99
init commit
huangxu96 Dec 7, 2021
5a80883
add put_along_axis_op and unit test
huangxu96 Dec 14, 2021
1e728a6
fix a lot of bugs
huangxu96 Dec 21, 2021
1d21b57
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangxu96 Dec 21, 2021
54347f9
for CI
huangxu96 Dec 21, 2021
9d5b99d
fix cmake dependency problem in CI
huangxu96 Dec 21, 2021
42a090f
modify per review suggestions and fix ROCm CI problem.
huangxu96 Dec 23, 2021
466cd27
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangxu96 Dec 28, 2021
44c624c
split this PR into two parts; this is the put_along_axis_op part
huangxu96 Dec 28, 2021
79ad7ec
fix a broadcast bug at the Python level
huangxu96 Dec 28, 2021
1581067
fix a bug in calculating the gradient of value
huangxu96 Dec 29, 2021
70a2b30
use TensorCopy instead of direct assignment
huangxu96 Dec 29, 2021
6c6c4d0
add inplace API for put_along_axis and unit tests.
huangxu96 Dec 30, 2021
46e0391
run pre-commit on manipulation.py
huangxu96 Dec 30, 2021
994b4a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhhsplendid Dec 30, 2021
c2283c7
fix error message typo and support GPU inplace computation
huangxu96 Dec 30, 2021
10deafa
fix error message format problem
huangxu96 Dec 30, 2021
b603bf2
Merge branch 'gather' of https://github.com/huangxu96/Paddle into gather
zhhsplendid Dec 30, 2021
dc7834f
Fix doc grammar error.
zhhsplendid Dec 30, 2021
e35265e
revert a wrong copy
huangxu96 Dec 30, 2021
a3cb00e
Merge branch 'gather' of https://github.com/huangxu96/Paddle into gather
zhhsplendid Dec 30, 2021
698e19f
Remove unused import
zhhsplendid Dec 30, 2021
a7dd3b6
Fix API doc problem
zhhsplendid Dec 30, 2021
60 changes: 57 additions & 3 deletions paddle/fluid/operators/gather_scatter_kernel.cc
@@ -34,9 +34,17 @@ class ReduceAdd {
*self_data += *src_data;
}
};

static ReduceAdd reduce_add;

class ReduceMultiply {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data *= *src_data;
}
};
static ReduceMultiply reduce_mul;

template <typename tensor_t, typename index_t = int64_t,
bool is_scatter_like = true>
struct cpu_gather_scatter_functor {
@@ -75,7 +83,6 @@ struct cpu_gather_scatter_functor {
for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t index_idx = 0;
int64_t self_idx, src_idx;

@@ -141,8 +148,55 @@ void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_cpu", reduce_add, ctx);
}

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_cpu", reduce_mul, ctx);
}
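
For intuition, here is a hypothetical numpy reference (not part of this diff) for what the assign/add/mul reduce modes compute along an axis; it assumes unique indices, whereas the kernels above fold duplicate indices in sequentially:

import numpy as np

def put_along_axis_ref(arr, index, value, axis, reduce="assign"):
    # Hypothetical reference mirroring scatter_assign/add/mul semantics.
    out = arr.copy()
    if reduce == "assign":
        np.put_along_axis(out, index, value, axis)
    elif reduce == "add":
        cur = np.take_along_axis(out, index, axis)
        np.put_along_axis(out, index, cur + value, axis)
    elif reduce == "mul":
        cur = np.take_along_axis(out, index, axis)
        np.put_along_axis(out, index, cur * value, axis)
    return out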

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor output,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* output_data = output.data<tensor_t>();

auto index_dims = index.dims();
auto output_dims = output.dims();

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int output_select_dim_size = output_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t index_idx = 0;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * output_select_dim_size;
output_data[replace_index] = 0;
index_idx++;
}
}
}
}
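
The net effect for the assign reduce: the gradient w.r.t. Input is the upstream gradient with every position selected by Index zeroed, since those elements of Input were overwritten by Value. A hypothetical numpy sketch of the same computation:

import numpy as np

def put_along_axis_input_grad(grad_result, index, axis):
    # Positions not written by the scatter pass the gradient through;
    # overwritten positions contribute nothing to Input's gradient.
    grad_input = grad_result.copy()
    np.put_along_axis(grad_input, index, 0, axis)
    return grad_input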

Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_assign_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)
Instantiate_Template_Function(cpu_scatter_mul_kernel)
Instantiate_Template_Function(cpu_scatter_input_grad_kernel)


} // namespace operators
} // namespace paddle
85 changes: 83 additions & 2 deletions paddle/fluid/operators/gather_scatter_kernel.cu
@@ -45,6 +45,16 @@ class ReduceAdd {
};
static ReduceAdd reduce_add;

class ReduceMul {
public:
template <typename tensor_t>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
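// Note: this read-modify-write is not atomic, so concurrent threads
// updating the same output element may race (hence the TODO below).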
*self_data *= *src_data;
// TODO(huangxu96) platform::CudaAtomicMul(*self_data, *src_data);
}
};
static ReduceMul reduce_mul;

template <typename tensor_t, typename index_t, typename func_t,
bool is_scatter_like = true>
__global__ void GatherScatterGPUKernel(
@@ -141,6 +151,14 @@ void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
return;
}

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
@@ -149,9 +167,72 @@ void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
self, dim, index, src, "scatter_add_gpu", reduce_add, ctx);
}

namespace plat = paddle::platform;
template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx);
}

template <typename tensor_t, typename index_t>
__global__ void ScatterInputGradGPUKernel(
tensor_t* grad_data, int dim, const index_t* index_data,
int64_t inner_dim_size, int select_dim_size, int grad_select_dim_size,
int64_t outer_dim_size, int64_t numel) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
j = remind / outer_dim_size;
k = remind % outer_dim_size;
index_t index = index_data[tid];
int64_t replace_index =
k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size;
grad_data[replace_index] = 0;
}
template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor grad,
const platform::DeviceContext& ctx) {
auto* index_data = index.data<index_t>();
auto* grad_data = grad.data<tensor_t>();

auto index_dims = index.dims();
auto grad_dims = grad.dims();
int64_t index_size = index.numel();

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
int select_dim_size = index_dims[dim];
int grad_select_dim_size = grad_dims[dim];
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t slice_size = 1;
for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];

int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();

ScatterInputGradGPUKernel<tensor_t, index_t><<<grid, block, 0, stream>>>(
grad_data, dim, index_data, inner_dim_size, select_dim_size,
grad_select_dim_size, outer_dim_size, index_size);
}
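
Each GPU thread handles one index element; the flat tid is decomposed into the same (i, j, k) coordinates that the CPU kernel's triple loop enumerates. A hypothetical Python check of that mapping:

def decompose(tid, select_dim_size, outer_dim_size):
    # Mirrors the index arithmetic in ScatterInputGradGPUKernel.
    i = tid // (select_dim_size * outer_dim_size)
    rem = tid % (select_dim_size * outer_dim_size)
    j = rem // outer_dim_size
    k = rem % outer_dim_size
    return (i, j, k)

# tid order matches the CPU kernel's i-major, then j, then k loop nest.
inner, select, outer = 2, 3, 4  # example sizes
assert [decompose(t, select, outer) for t in range(inner * select * outer)] == \
    [(i, j, k) for i in range(inner) for j in range(select) for k in range(outer)]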
Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_assign_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)
Instantiate_Template_Function(gpu_scatter_mul_kernel)
Instantiate_Template_Function(gpu_scatter_input_grad_kernel)


} // namespace operators
} // namespace paddle
25 changes: 25 additions & 0 deletions paddle/fluid/operators/gather_scatter_kernel.h
@@ -41,17 +41,42 @@ template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
Tensor result,
const platform::DeviceContext& ctx);
} // namespace operators
} // namespace paddle
138 changes: 138 additions & 0 deletions paddle/fluid/operators/put_along_axis_op.cc
@@ -0,0 +1,138 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/put_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

class PutAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasInput("Value"), "Input", "Value", "PutAlongAxis");
OP_INOUT_CHECK(ctx->HasOutput("Result"), "Output", "Result",
"PutAlongAxis");

auto index_dim = ctx->GetInputDim("Index");

ctx->SetOutputDim("Result", index_dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor of PutAlongAxisOp");
AddInput("Index", "The index tensor of PutAlongAxisOp");
AddInput("Value", "The value tensor of PutAlongAxisOp");
AddOutput("Result", "The result tensor of PutAlongAxisOp");
AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation");
AddAttr<std::string>("Reduce", "The reduce operation for scatter")
.SetDefault("assign");
AddComment(R"DOC(
PutAlongAxis Operator.
)DOC");
}
};

class PutAlongAxisGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Result")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

template <typename T>
class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("put_along_axis_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Input", this->Input("Input"));

op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value"));
op->SetAttrMap(this->Attrs());
}
};

DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::PutAlongAxisInplaceInferer);

REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);

REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
ops::PutAlongAxisOpKernel<double>,
ops::PutAlongAxisOpKernel<int>,
ops::PutAlongAxisOpKernel<uint8_t>,
ops::PutAlongAxisOpKernel<int64_t>);

REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
ops::PutAlongAxisGradOpKernel<float>,
ops::PutAlongAxisGradOpKernel<double>,
ops::PutAlongAxisGradOpKernel<int>,
ops::PutAlongAxisGradOpKernel<uint8_t>,
ops::PutAlongAxisGradOpKernel<int64_t>);
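
For reference, the Python API this operator backs (added in python/paddle/tensor/manipulation.py alongside these kernels) is used roughly as below; the exact call is a sketch based on the op's Input/Index/Value inputs and Axis/Reduce attributes:

import paddle

x = paddle.to_tensor([[10, 30, 20], [60, 40, 50]], dtype="float32")
index = paddle.to_tensor([[0]], dtype="int64")
# index/value are broadcast to x's shape at the Python level before the op runs.
result = paddle.put_along_axis(x, index, 99.0, 0)  # reduce defaults to "assign"
# result: [[99., 99., 99.],
#          [60., 40., 50.]]

The PR also registers an inplace path (PutAlongAxisInplaceInferer), presumably exposed on the Python side as the trailing-underscore variant put_along_axis_, following Paddle's inplace naming convention.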