-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Put_along_axis #37921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Put_along_axis #37921
Changes from 15 commits
f56556c
a292f99
5a80883
1e728a6
1d21b57
54347f9
9d5b99d
42a090f
466cd27
44c624c
79ad7ec
1581067
70a2b30
6c6c4d0
46e0391
c2283c7
10deafa
e35265e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,17 +41,42 @@ template <typename tensor_t, typename index_t> | |
| void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggestion:
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since I will modify the self tensor in my code, declaring it as a const variable is not necessary here. |
||
| const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor result, | ||
| const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result, | ||
| const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor src, const platform::DeviceContext& ctx); | ||
|
|
||
| template <typename tensor_t, typename index_t> | ||
| void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index, | ||
| Tensor result, | ||
| const platform::DeviceContext& ctx); | ||
| } // namespace operators | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/operators/put_along_axis_op.h" | ||
| #include <memory> | ||
| #include <string> | ||
| #include <vector> | ||
| #include "paddle/fluid/framework/ddim.h" | ||
| #include "paddle/fluid/framework/op_version_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class PutAlongAxisOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| PADDLE_ENFORCE_EQ( | ||
|
||
| ctx->HasInput("Input"), true, | ||
| platform::errors::InvalidArgument( | ||
| "Input(Input) of PutAlongAxisOpOp should not be null.")); | ||
| PADDLE_ENFORCE_EQ( | ||
| ctx->HasInput("Index"), true, | ||
| platform::errors::InvalidArgument( | ||
| "Input(Index) of PutAlongAxisOpOp should not be null.")); | ||
| PADDLE_ENFORCE_EQ( | ||
| ctx->HasInput("Value"), true, | ||
| platform::errors::InvalidArgument( | ||
| "Input(Value) of PutAlongAxisOpOp should not be null.")); | ||
| PADDLE_ENFORCE_EQ( | ||
| ctx->HasOutput("Result"), true, | ||
| platform::errors::InvalidArgument( | ||
| "Output(Result) of PutAlongAxisOpOp should not be null.")); | ||
|
|
||
| auto index_dim = ctx->GetInputDim("Index"); | ||
|
|
||
| ctx->SetOutputDim("Result", index_dim); | ||
| } | ||
|
|
||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext& ctx) const override { | ||
| return framework::OpKernelType( | ||
| OperatorWithKernel::IndicateVarDataType(ctx, "Input"), | ||
| ctx.device_context()); | ||
| } | ||
| framework::OpKernelType GetKernelTypeForVar( | ||
| const std::string& var_name, const framework::Tensor& tensor, | ||
| const framework::OpKernelType& expected_kernel_type) const override { | ||
| return framework::OpKernelType(expected_kernel_type.data_type_, | ||
| tensor.place(), tensor.layout()); | ||
| } | ||
| }; | ||
|
|
||
| class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| void Make() override { | ||
| AddInput("Input", "The input tensor of PutAlongAxisOp"); | ||
| AddInput("Index", "The index tensor of PutAlongAxisOp"); | ||
| AddInput("Value", "The value tensor of PutAlongAxisOp"); | ||
| AddOutput("Result", "The result tensor of PutAlongAxisOp"); | ||
| AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation"); | ||
| AddAttr<std::string>("Reduce", "The reduce operation for scatter") | ||
| .SetDefault("assign"); | ||
| AddComment(R"DOC( | ||
| PutAlongAxis Operator.) | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| class PutAlongAxisGradOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| ctx->SetOutputDim(framework::GradVarName("Input"), | ||
| ctx->GetInputDim("Input")); | ||
| } | ||
|
|
||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext& ctx) const override { | ||
| return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( | ||
| ctx, framework::GradVarName("Result")), | ||
| ctx.device_context()); | ||
| } | ||
| framework::OpKernelType GetKernelTypeForVar( | ||
| const std::string& var_name, const framework::Tensor& tensor, | ||
| const framework::OpKernelType& expected_kernel_type) const override { | ||
| return framework::OpKernelType(expected_kernel_type.data_type_, | ||
| tensor.place(), tensor.layout()); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
| public: | ||
| using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
|
||
| protected: | ||
| void Apply(GradOpPtr<T> op) const override { | ||
| op->SetType("put_along_axis_grad"); | ||
| op->SetInput("Index", this->Input("Index")); | ||
| op->SetInput("Input", this->Input("Input")); | ||
|
|
||
| op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result")); | ||
| op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); | ||
iclementine marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value")); | ||
| op->SetAttrMap(this->Attrs()); | ||
| } | ||
| }; | ||
|
|
||
| DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"}); | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker, | ||
| ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>, | ||
| ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>, | ||
| paddle::operators::PutAlongAxisInplaceInferer); | ||
|
|
||
| REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp); | ||
|
|
||
| REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>, | ||
| ops::PutAlongAxisOpKernel<double>, | ||
| ops::PutAlongAxisOpKernel<int>, | ||
| ops::PutAlongAxisOpKernel<uint8_t>, | ||
| ops::PutAlongAxisOpKernel<int64_t>); | ||
|
|
||
| REGISTER_OP_CPU_KERNEL(put_along_axis_grad, | ||
| ops::PutAlongAxisGradOpKernel<float>, | ||
| ops::PutAlongAxisGradOpKernel<double>, | ||
| ops::PutAlongAxisGradOpKernel<int>, | ||
| ops::PutAlongAxisGradOpKernel<uint8_t>, | ||
| ops::PutAlongAxisGradOpKernel<int64_t>); | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better to split it into two classes: one for gather and the other for scatter. Reasons:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think splitting it is better, since scatter and gather share a lot of logic. Using one kernel function makes the code more elegant, and I don't want to introduce redundant code.