Skip to content

Conversation

@ZHUI
Copy link
Collaborator

@ZHUI ZHUI commented Jan 5, 2023

PR types

New features

PR changes

OPs

Describe

Fix 0-dim tensor for arg_min_max op.

@ZHUI ZHUI requested a review from zhwesky2010 January 5, 2023 07:51
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
if (flatten) {
if (x_rank == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也不用写分支,因为phi::product(x_dims); 里面已经支持了0D的product计算结果是1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

额,这里有个flatten的配置,怕这个 为 false的话,后面可能有问题。

Copy link
Contributor

@zhwesky2010 zhwesky2010 Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0D的时候axis只能为None,就是flatten的情况

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加下axis的检查吧,0D的axis只能为None

// TODO(ZHUI): fix dtype of out
dev_ctx.template Alloc<int64_t>(out);
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XPU的设置为常数:

xpu::constant<T>(
        dev_ctx.x_context(), dx_data, x->numel(), static_cast<T>(1.0));

@ZHUI ZHUI requested a review from zhwesky2010 February 1, 2023 08:47
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要再微调下单测写法:
单测在TestSundry和TestSundryStatic里分别实现,测试axis=0/-1/None三种情况,测试前向shape与前向值

@zhwesky2010 zhwesky2010 merged commit e4e94a8 into PaddlePaddle:develop Feb 1, 2023
@ZHUI ZHUI deleted the fix_0d_tensor branch February 1, 2023 09:03
pangengzheng pushed a commit to pangengzheng/Paddle that referenced this pull request Feb 2, 2023
* fix 0-d tensor for arg_min_max op.

* fix xpu.

* fix zero dims

* fix

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update test_zero_dim_tensor.py

* Update test_zero_dim_tensor_xpu.py

* Update test_zero_dim_tensor.py

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants