[Operator]Unify dygraph reshape op yaml to delete xshape output #67009
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
LGTM
```diff
 - op : reshape
   args : (Tensor x, IntArray shape)
-  output : Tensor(out), Tensor(xshape)
+  output : Tensor(out)
```
As stated in the PR description, the next PR will move the definitions of the reshape forward and backward operators from dygraph_xxx.yaml and static_xxx.yaml into ops.yaml/backward.yaml.
Great work!
LGTM for reshape inferspmd func.
LGTM
LGTM for inference
LGTM
PR Category
Operator Mechanism
PR Types
Deprecations
Description
Pcard-67164
What's New?
Building on #66089, this PR removes the XShape field of the Reshape operator from the dygraph yaml, unifying the dynamic- and static-graph definitions. The changes are as follows:
1. Normalize the yaml definitions
- sparse_op.yaml: no change needed; it already uses the normalized form without XShape
- dygraph_ops.yaml: remove the xshape output field
- dygraph_backward.yaml: reshape_grad and reshape_double_grad drop xshape in favor of x, and x is marked no_need_buffer
- static_ops.yaml: the entire reshape definition will be removed and merged with dygraph_ops.yaml into ops.yaml in the next PR
- static_backward.yaml: the entire reshape_grad definition will be removed and merged with dygraph_backward.yaml into backward.yaml in the next PR
2. Normalize the InferMeta definitions
Removed ReshapeGradInferMeta and added SameWithXInferMeta, which can be reused by the backward ops of Flatten/Squeeze/Unsqueeze and similar operators.
3. Normalize distributed spmd generation
In general, the distributed spmd generation logic defaults to using the same input arguments as infermeta, but dist_api_gen.py special-cases reshape because the kernel parameter list of reshape_grad does not use x. In my view, reshape_grad's kernel should depend on x explicitly, so that downstream components no longer need the special case; this PR normalizes it accordingly.
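The convention being restored here can be illustrated with a toy Python sketch of the codegen rule (the helper function and the emitted symbol name below are hypothetical illustrations, not the actual logic in dist_api_gen.py):

```python
def gen_spmd_call(op_name: str, kernel_args: list) -> str:
    """Build the spmd-rule invocation directly from the kernel's argument
    list -- the generic path that every op is supposed to take.

    Once reshape_grad's kernel lists `x` explicitly, this generic rule
    covers it and no per-op special case is needed in the generator.
    """
    # NOTE: the emitted symbol name here is a made-up illustration.
    fn = op_name.title().replace("_", "") + "InferSpmd"
    return f"{fn}({', '.join(kernel_args)})"

# With x in the kernel signature, reshape_grad goes through the same path:
call = gen_spmd_call("reshape_grad", ["x", "out_grad"])
```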
It also turned out that reshape.cc contains a ReshapeInferSpmdDynamic that performs the following handling:
4. Normalize the forward/backward decomposition logic
reshape is a primitive operator, so only the signature of reshape_grad needs to change, replacing xshape with x.
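The decomposition rule after this signature change can be sketched in Python, with NumPy arrays standing in for Paddle tensors (an illustration of the rule only, not Paddle's actual composite-backward code):

```python
import numpy as np

def reshape_grad(x: np.ndarray, out_grad: np.ndarray) -> np.ndarray:
    """x_grad is simply out_grad reshaped back to x's shape.

    Only x's shape (its meta) is needed here, which is why the yaml can
    mark x as no_need_buffer: x's data buffer never has to be kept alive
    for the backward pass.
    """
    return out_grad.reshape(x.shape)

x = np.zeros((2, 3, 4))
out_grad = np.ones((6, 4))          # gradient w.r.t. reshape(x, [6, 4])
x_grad = reshape_grad(x, out_grad)  # shape (2, 3, 4)
```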
5. Normalize kernel naming and registration
To execute the legacy static-graph Reshape2 operator, it must be correctly mapped to ReshapeWithXShapeKernel. In terms of implementation, this requires changes in reshape_sig.cc, for example:
6. Dygraph inplace + stride mechanism
Under the inplace and stride mechanisms, removing xshape from reshape_ causes x_dims to be fetched incorrectly, breaking the stride kernel's computation logic.
Dygraph inplace is implemented as a DenseTensor aliasing mechanism. To prevent the infer meta stage from incorrectly modifying the input's Meta information, which would leave the kernel layer unable to obtain the correct infer_meta result, the InferMeta handling in api_base.py was improved.
For some fields of the inplace reshape_ API, a temporary shallow copy of the DenseTensor is made before infermeta, so that the correct Meta information is preserved across infermeta. For example, this PR changes reshape_'s Python C API code as follows:
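The generated Python C API code itself is not reproduced above; as a stand-in, here is a toy Python sketch of the shallow-copy idea (the class and all names are hypothetical, not Paddle's real API):

```python
import copy

class DenseTensor:
    """Toy tensor: under inplace aliasing, `data` is shared while each
    copy keeps its own meta (dims)."""
    def __init__(self, dims, data):
        self.dims = dims      # meta information
        self.data = data      # storage, aliased by inplace ops

def reshape_(x: DenseTensor, shape):
    # Shallow-copy x BEFORE infer_meta mutates the alias, so the stride
    # kernel can still read x's original dims afterwards.
    x_meta_backup = copy.copy(x)   # shares data, preserves the old meta
    out = x                        # inplace: out aliases x
    out.dims = list(shape)         # infer_meta updates the shared alias
    return out, x_meta_backup

x = DenseTensor([2, 6], list(range(12)))
out, saved = reshape_(x, [3, 4])
# out.dims is [3, 4]; saved.dims still holds the original [2, 6],
# while saved.data and out.data are the same object.
```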