-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Put_along_axis #37921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Put_along_axis #37921
Conversation
|
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already a good job. However the code is too long so that I left a lot of comments. Can you split the PR in the future? (You don't have to do it in this PR) It seems the PR exceeding 1000 lines would require Dianhai to review (that may be slow ...)
| ops::TakeAlongAxisOpKernel<double>, | ||
| ops::TakeAlongAxisOpKernel<int>, | ||
| ops::TakeAlongAxisOpKernel<uint8_t>, | ||
| ops::TakeAlongAxisOpKernel<int64_t>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to register for complex64 and complex128?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, since the python API level won't support complex type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is complex type not supported in python?
There are some python APIs with complex dtype support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- There are some suggesting on code style and sementic issues;
- implementation of the gradient of put along aixs seems errorous(both the gradient of input & value should be computed, correctly). I think this is also the reason why it cannot pass grad check.
- add docstring for APIs.
e03d420 to
31715a0
Compare
fc1e7e8 to
ee0c213
Compare
| VLOG(3) << "<<<< Done gpu_scatter_mul_kernel <<<<<"; | ||
| } | ||
|
|
||
| namespace plat = paddle::platform; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems you didn't change what I said...
|
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| PADDLE_ENFORCE_EQ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里建议使用OP_INOUT_CHECK
Paddle new APIs: put_along_axis. Xu Huang is on holiday so we created this PR to work on it. It is based on his PR: #37921
PR types
New features
PR changes
APIs
Describe
Paddle new APIs: put_along_axis