Put_along_axis #37921
Closed
Commits (18):
- f56556c init commit (huangxu96)
- a292f99 init commit (huangxu96)
- 5a80883 add put_along_axis_op and unitest (huangxu96)
- 1e728a6 fix a lot of bug (huangxu96)
- 1d21b57 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (huangxu96)
- 54347f9 for CI (huangxu96)
- 9d5b99d fix cmake depency problem in CI (huangxu96)
- 42a090f modified as review suggestion and fix Rocm CI problem. (huangxu96)
- 466cd27 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (huangxu96)
- 44c624c split this PR into two parts, this is the put_along_axis_op part (huangxu96)
- 79ad7ec fix a bug in broadcast in python level (huangxu96)
- 1581067 fix a bug in caculate gradient of value (huangxu96)
- 70a2b30 using TensorCopy instead directly assign (huangxu96)
- 6c6c4d0 add inplace API for put_along_axis and unittest. (huangxu96)
- 46e0391 used pre-commit for manipulation.py (huangxu96)
- c2283c7 fix error message typo and support gpu inplace computation (huangxu96)
- 10deafa fix error message format problem (huangxu96)
- e35265e revert a wrong copy (huangxu96)
Modified header, hunk `@@ -41,17 +41,42 @@`:

```cpp
template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
                       const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
                               Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
                            Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
                            Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
                                   Tensor result,
                                   const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
                       const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
                               Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
                            Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_mul_kernel(Tensor self, int dim, const Tensor& index,
                            Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
                                   Tensor result,
                                   const platform::DeviceContext& ctx);
}  // namespace operators
}  // namespace paddle
```

Review note from the author on the non-const `Tensor self` parameter: "Since I will modify the self tensor in my code, declaring it as a const variable is not necessary here."
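The declarations above pair gather kernels (read along an axis) with scatter kernels (write along an axis). As an illustration only, their semantics can be sketched with NumPy's `take_along_axis`/`put_along_axis` (NumPy names, not the Paddle kernels themselves):

```python
import numpy as np

# Sketch of gather/scatter-along-axis semantics using NumPy equivalents.
x = np.array([[10, 20, 30],
              [40, 50, 60]])
idx = np.array([[2, 0],
                [1, 1]])  # one column index per output slot

# gather: gathered[i, j] = x[i, idx[i, j]]
gathered = np.take_along_axis(x, idx, axis=1)
# gathered == [[30, 10],
#              [50, 50]]

# scatter-assign: out[i, idx[i, j]] = values[i, j]
# (for duplicate indices, the last write wins)
out = x.copy()
np.put_along_axis(out, idx, np.array([[-1, -2],
                                      [-3, -4]]), axis=1)
# out == [[-2, 20, -1],
#         [40, -4, 60]]
```

Note that the index tensor selects a coordinate only along the chosen axis; the remaining coordinates are taken from the index element's own position.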
New file, hunk `@@ -0,0 +1,138 @@`:

```cpp
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/put_along_axis_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

class PutAlongAxisOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "PutAlongAxis");
    OP_INOUT_CHECK(ctx->HasInput("Index"), "Input", "Index", "PutAlongAxis");
    OP_INOUT_CHECK(ctx->HasInput("Value"), "Input", "Value", "PutAlongAxis");
    OP_INOUT_CHECK(ctx->HasOutput("Result"), "Output", "Result",
                   "PutAlongAxis");

    auto index_dim = ctx->GetInputDim("Index");

    ctx->SetOutputDim("Result", index_dim);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "The input tensor of PutAlongAxisOp");
    AddInput("Index", "The index tensor of PutAlongAxisOp");
    AddInput("Value", "The value tensor of PutAlongAxisOp");
    AddOutput("Result", "The result tensor of PutAlongAxisOp");
    AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation");
    AddAttr<std::string>("Reduce", "The reduce operation for scatter")
        .SetDefault("assign");
    AddComment(R"DOC(
PutAlongAxis Operator.)
    )DOC");
  }
};

class PutAlongAxisGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("Input"),
                      ctx->GetInputDim("Input"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Result")),
                                   ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

template <typename T>
class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("put_along_axis_grad");
    op->SetInput("Index", this->Input("Index"));
    op->SetInput("Input", this->Input("Input"));

    op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(put_along_axis, ops::PutAlongAxisOp, ops::PutAlongAxisOpMaker,
                  ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
                  ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
                  paddle::operators::PutAlongAxisInplaceInferer);

REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);

REGISTER_OP_CPU_KERNEL(put_along_axis, ops::PutAlongAxisOpKernel<float>,
                       ops::PutAlongAxisOpKernel<double>,
                       ops::PutAlongAxisOpKernel<int>,
                       ops::PutAlongAxisOpKernel<uint8_t>,
                       ops::PutAlongAxisOpKernel<int64_t>);

REGISTER_OP_CPU_KERNEL(put_along_axis_grad,
                       ops::PutAlongAxisGradOpKernel<float>,
                       ops::PutAlongAxisGradOpKernel<double>,
                       ops::PutAlongAxisGradOpKernel<int>,
                       ops::PutAlongAxisGradOpKernel<uint8_t>,
                       ops::PutAlongAxisGradOpKernel<int64_t>);
```
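The op's `Reduce` attribute defaults to `"assign"`, and the header in this PR also declares `scatter_add` and `scatter_mul` kernels. A loop-based NumPy sketch of these reduce modes (`put_along_axis_ref` is a hypothetical helper name for illustration, not a Paddle API):

```python
import numpy as np

def put_along_axis_ref(arr, idx, values, axis, reduce="assign"):
    # Reference semantics of the Reduce attribute: each index element
    # redirects the coordinate on `axis`, then assigns/adds/multiplies.
    out = arr.copy()
    for pos in np.ndindex(*idx.shape):
        target = list(pos)
        target[axis] = idx[pos]       # redirect only the `axis` coordinate
        target = tuple(target)
        if reduce == "assign":
            out[target] = values[pos]
        elif reduce == "add":
            out[target] += values[pos]
        elif reduce == "mul":
            out[target] *= values[pos]
        else:
            raise ValueError("unsupported reduce: " + reduce)
    return out

x = np.ones((2, 3), dtype=np.int64)
idx = np.array([[0, 1, 2]])
vals = np.array([[5, 6, 7]])
print(put_along_axis_ref(x, idx, vals, axis=1, reduce="add"))
# [[6 7 8]
#  [1 1 1]]
```

With `reduce="assign"` the same call would overwrite instead of accumulate, giving `[[5, 6, 7], [1, 1, 1]]`.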
Review discussion:

Reviewer: I think it is better to split it into two classes, one for gather and the other for scatter.

Author: I don't think splitting it is better, since scatter and gather share a lot of logic. A single kernel function keeps the code more elegant and avoids redundant code.
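For context on the gradient wiring in `PutAlongAxisGradOpMaker` above (gradients for both `Input` and `Value`), here is a hedged NumPy sketch, not Paddle's implementation, of the backward pass for the default `reduce="assign"`, assuming `Index` has the input's shape and ignoring duplicate-index subtleties:

```python
import numpy as np

def put_along_axis_assign_grads(grad_result, idx, axis):
    # Sketch: slots overwritten by Value receive zero gradient in Input,
    # and Value's gradient gathers the upstream gradient at Index.
    grad_input = grad_result.copy()
    np.put_along_axis(grad_input, idx, 0, axis=axis)  # zero overwritten slots
    grad_value = np.take_along_axis(grad_result, idx, axis=axis)
    return grad_input, grad_value

g = np.array([[1.0, 2.0],
              [3.0, 4.0]])
idx = np.array([[0, 0],
                [1, 0]])
gi, gv = put_along_axis_assign_grads(g, idx, axis=1)
# gi == [[0., 2.], [0., 0.]]
# gv == [[1., 1.], [4., 3.]]
```

This matches the intuition behind `cpu_scatter_input_grad_kernel` declared in the header: positions not written through keep the upstream gradient, while written positions route it to `Value` instead.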