
Conversation

Contributor

@Aurelius84 Aurelius84 commented Aug 5, 2024

PR Category

Operator Mechanism

PR Types

Deprecations

Description

Pcard-67164

What's New?

Building on #66089, this PR removes the XShape field of the Reshape operator from the dynamic-graph Yaml definitions, unifying the dynamic and static graph forms. The concrete changes are as follows:

1. Normalize the yaml definitions

  • sparse_op.yaml: no change needed; it is already in the normalized form without XShape
  • dygraph_ops.yaml: remove the xshape output field
  • dygraph_backward.yaml: remove xshape from reshape_grad and reshape_double_grad, depend on x instead, and mark x as no_need_buffer
  • static_ops.yaml: the whole reshape definition is to be removed and merged with dygraph_ops.yaml into ops.yaml; this will be done in the next PR
  • static_backward.yaml: the whole reshape_grad definition is to be removed and merged with dygraph_backward.yaml into backward.yaml; this will be done in the next PR

Note: ops.yaml and backward.yaml are also reused by the old IR. To keep operators like reshape2 from getting duplicate generated definitions, they must be masked out, which can be done in ops_exclude.yaml. A sketch of the normalized entries follows.
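Hedged sketch of the normalized entries (the infer_meta/kernel fields and exact layout are assumptions modeled on the existing ops.yaml conventions, not the literal diff):

- op : reshape
  args : (Tensor x, IntArray shape)
  output : Tensor(out)
  infer_meta :
    func : ReshapeInferMeta
  kernel :
    func : reshape
  backward : reshape_grad

- backward_op : reshape_grad
  forward : reshape (Tensor x, IntArray shape) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : SameWithXInferMeta
    param : [x]
  kernel :
    func : reshape_grad
  no_need_buffer : x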

2. Normalize the InferMeta definition

Remove ReshapeGradInferMeta and add SameWithXInferMeta, which can then be reused by the backward ops of Flatten/Squeeze/Unsqueeze and similar operators; see the sketch below.
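A minimal sketch of what SameWithXInferMeta might look like (the exact set of meta fields copied is an assumption; the point is that the grad output's meta simply mirrors x, so one shared function serves several backward ops):

void SameWithXInferMeta(const MetaTensor& x, MetaTensor* out) {
  // The gradient w.r.t. x has exactly the same meta as x itself,
  // so copying x's meta is all the shape inference these ops need.
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
  out->share_lod(x);
}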

3. Normalize distributed spmd generation

In general, distributed spmd generation defaults to the same input arguments as infermeta, but dist_api_gen.py special-cases reshape because x is absent from reshape_grad's kernel parameter list. In my view, the reshape_grad kernel should depend on x explicitly so that downstream components need no special handling, and this PR normalizes it accordingly.

We also found a ReshapeInferSpmdDynamic in reshape.cc, which is handled as follows:

  • It is the default configuration in static_ops.yaml and dygraph_ops.yaml. Judging from the implementation, it exists only to special-case xshape, so it needs cleaning up once xshape is retired.
  • Change the mapping function in the two yaml files above to ReshapeInferSpmd (see the sketch below).
  • Functions with the ReshapeInferSpmdDynamic suffix are still used in the sharding-rule derivation of flatten and others; they will be cleaned up uniformly after those operators are migrated.
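Roughly, the mapping change in both yaml files looks like this (a sketch with the surrounding fields elided; the exact layout is an assumption):

  infer_meta :
    func : ReshapeInferMeta
    spmd_rule : ReshapeInferSpmd   # previously ReshapeInferSpmdDynamic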

4. Normalize the forward/backward decomposition logic

reshape is a primitive operator, so it suffices to change reshape_grad's signature, replacing xshape with x; a sketch of the adjusted rule follows.
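A hedged sketch of the adjusted rule in Paddle's composite-backward style (the helpers reshape, set_output, and common::vectorize are assumptions here, not the literal diff):

template <typename T>
void reshape_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    // Only x's dims are consulted, which is exactly why x can be
    // marked no_need_buffer in the yaml definition.
    auto dx = reshape<T>(out_grad, common::vectorize(x.dims()));
    set_output<T>(dx, x_grad);
  }
}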

5. Normalize kernel naming and registration

  • For dynamic graph and the new IR, rename ReshapeInferKernel to ReshapeKernel.
  • For the old static graph, rename ReshapeKernel to ReshapeWithXShapeKernel, with a NOTE marking that it will be retired together with the old static graph (a registration sketch follows).
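A hedged registration sketch (the macro form follows phi's general-kernel registration style; the backend/dtype tags here are assumptions):

// Dynamic graph / new IR: the kernel formerly registered as
// "reshape_infer" (ReshapeInferKernel) now registers as "reshape".
PD_REGISTER_GENERAL_KERNEL(
    reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}

// Old static graph only: the xshape-producing variant lives on as
// "reshape_with_xshape" and retires together with the old static graph.
PD_REGISTER_GENERAL_KERNEL(
    reshape_with_xshape, CPU, ALL_LAYOUT,
    phi::ReshapeWithXShapeKernel<phi::CPUContext>, ALL_DTYPE) {}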

For the old static graph to execute the Reshape2 operator, it must be correctly mapped to ReshapeWithXShapeKernel. Implementation-wise, this requires changes in reshape_sig.cc, for example:

KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.HasOutput("XShape")) {
    if (ctx.InputSize("ShapeTensor") > 0) {
      return KernelSignature(
          "reshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"});   // --> change this to reshape_with_xshape
    } else if (ctx.HasInput("Shape")) {
      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out", "XShape"});
    } else {
      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out", "XShape"});
    }
  } else {
    if (ctx.InputSize("ShapeTensor") > 0) {
      return KernelSignature("reshape_infer", {"X"}, {"ShapeTensor"}, {"Out"});  // ---> normalize this to reshape
    } else if (ctx.HasInput("Shape")) {
      return KernelSignature("reshape_infer", {"X"}, {"Shape"}, {"Out"});
    } else {
      return KernelSignature("reshape_infer", {"X"}, {"shape"}, {"Out"});
    }
  }
}

Note: there is a small gotcha in the backward signature. Previously the reshape_grad kernel had no explicit dependency on x, because under the old IR the reshape_grad operator definition carries xshape rather than x. After normalizing the reshape_grad kernel it depends on x explicitly, so the reshape_grad argument mapping applies a small trick (feeding XShape into the slot that now expects x). This trick can be removed together with the old IR operator system when it is retired.

KernelSignature ReshapeGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("reshape_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
}

6. Dynamic-graph inplace + stride mechanisms

Under the inplace and stride mechanisms, removing xshape causes reshape_ to fetch wrong x_dims, breaking the compute logic of the stride kernel.

The dynamic-graph inplace strategy is an aliasing mechanism over DenseTensor. To keep the infer-meta stage from wrongly modifying the input's meta information, which would leave the kernel unable to obtain the correct inferred meta, the InferMeta handling in api_base.py has been refined.

For some inputs of inplace APIs such as reshape_, a temporary shallow copy of the DenseTensor is taken before infermeta, guaranteeing that infermeta sees the correct meta information. For example, the generated Python C API code of reshape_ changes as follows in this PR:

  Tensor& api_output = x;  // <---- under inplace, out aliases x directly, so infermeta would overwrite x's meta
  auto kernel_out = SetKernelOutput(&api_output);

  phi::RecordEvent *infer_shape_record_event = nullptr;
  if(phi::RecordEvent::IsEnabled()){
    infer_shape_record_event = new phi::RecordEvent("reshape infer_meta", phi::TracerEventType::OperatorInner, 1);
  }

  auto origin_input_x = *input_x;     // <----- shallow-copy the input tensor here
  phi::MetaTensor meta_out(kernel_out, kernel_result.is_stride_kernel);

  phi::ReshapeInferMeta(MakeMetaTensor(origin_input_x), shape, &meta_out); // <--- run infermeta against origin_input_x

  if(infer_shape_record_event != nullptr){
    delete infer_shape_record_event;
  }
  using kernel_signature = void(*)(const phi::DeviceContext&, const phi::DenseTensor&, const phi::IntArray&, phi::DenseTensor*);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  phi::RecordEvent* kernel_record_event = nullptr;
  if(phi::RecordEvent::IsEnabled()){
    kernel_record_event = new phi::RecordEvent("reshape compute", phi::TracerEventType::OperatorInner, 1);
  }
  (*kernel_fn)(*dev_ctx, origin_input_x, phi::IntArray(shape), kernel_out);
  if(kernel_record_event != nullptr){
    delete kernel_record_event;
  }

@paddle-bot
paddle-bot bot commented Aug 5, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

- op : reshape
  args : (Tensor x, IntArray shape)
  output : Tensor(out), Tensor(xshape)   # before
  output : Tensor(out)                   # after
Contributor Author

As described in the PR, the next PR will move the reshape forward/backward operator definitions in dygraph_xxx.yaml and static_xxx.yaml into ops.yaml and backward.yaml.

Contributor

@chen2016013 chen2016013 left a comment

Great work!

Contributor

@winter-wang winter-wang left a comment

LGTM for reshape inferspmd func.

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@Aurelius84 Aurelius84 requested a review from zyfncg August 7, 2024 07:26
Contributor

@vivienfanghuagood vivienfanghuagood left a comment

LGTM for inference

Contributor

@From00 From00 left a comment

LGTM
