Is there a problem with the current inplace kernel computation logic? #70853

@USTCKAY

Description

Please ask your question

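To prevent the infer meta stage from incorrectly modifying the input's meta information, #67009 makes a temporary shallow copy of the DenseTensor before infermeta. The code example given there is as follows: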

  Tensor& api_output = x;  // <---- under the inplace mechanism, out directly reuses x, so x's meta gets overwritten after infermeta
  auto kernel_out = SetKernelOutput(&api_output);

  phi::RecordEvent *infer_shape_record_event = nullptr;
  if(phi::RecordEvent::IsEnabled()){
    infer_shape_record_event = new phi::RecordEvent("reshape infer_meta", phi::TracerEventType::OperatorInner, 1);
  }

  auto origin_input_x = *input_x;     // <----- the input Tensor is shallow-copied here
  phi::MetaTensor meta_out(kernel_out, kernel_result.is_stride_kernel);

  phi::ReshapeInferMeta(MakeMetaTensor(origin_input_x), shape, &meta_out); // <--- infermeta uses origin_input_x

  if(infer_shape_record_event != nullptr){
    delete infer_shape_record_event;
  }
  using kernel_signature = void(*)(const phi::DeviceContext&, const phi::DenseTensor&, const phi::IntArray&, phi::DenseTensor*);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  phi::RecordEvent* kernel_record_event = nullptr;
  if(phi::RecordEvent::IsEnabled()){
    kernel_record_event = new phi::RecordEvent("reshape compute", phi::TracerEventType::OperatorInner, 1);
  }
  (*kernel_fn)(*dev_ctx, origin_input_x, phi::IntArray(shape), kernel_out);
  if(kernel_record_event != nullptr){
    delete kernel_record_event;
  }

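However, there seems to be a problem here. The kernel call (*kernel_fn)(*dev_ctx, origin_input_x, phi::IntArray(shape), kernel_out); takes the shallow-copied tensor origin_input_x as its input, while the output kernel_out is a pointer to the original tensor. As a result, the input and output are not actually the same tensor object (although they share the same allocation), which causes trouble for some inplace checks inside kernels. So when calling the kernel, should the input be changed back to the original tensor, i.e. should the call be (*kernel_fn)(*dev_ctx, *input_x, phi::IntArray(shape), kernel_out); instead? @Aurelius84, could you please help answer this? Thanks!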
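
To make the concern concrete, here is a minimal standalone sketch of the situation. It does not use phi at all: ToyTensor, ToyReshapeInferMeta, IsInplaceByAddress, IsInplaceByAllocation and RunDemo are hypothetical names made up for this illustration, and the toy only assumes that a shallow copy duplicates the meta while sharing the allocation. Under that assumption, an inplace check that compares object addresses fails for the copy, while a check that compares the shared allocation still passes.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for a dense tensor: meta (dims) plus a shared buffer.
// Copying a ToyTensor copies dims but shares the allocation (a shallow copy).
struct ToyTensor {
  std::vector<int64_t> dims;
  std::shared_ptr<std::vector<float>> holder;
};

// An inplace check that compares object identity (addresses).
bool IsInplaceByAddress(const ToyTensor& in, const ToyTensor* out) {
  return &in == out;
}

// An inplace check that compares the underlying allocation instead.
bool IsInplaceByAllocation(const ToyTensor& in, const ToyTensor* out) {
  return in.holder != nullptr && in.holder == out->holder;
}

// Toy InferMeta for reshape: only writes the new shape into the output meta.
void ToyReshapeInferMeta(const ToyTensor& in,
                         const std::vector<int64_t>& shape,
                         ToyTensor* out) {
  (void)in;  // the input meta is only read in a real InferMeta
  out->dims = shape;
}

struct DemoResult {
  bool copy_keeps_old_dims;        // shallow copy still has the original dims
  bool original_has_new_dims;      // x's meta was overwritten through kernel_out
  bool address_check_on_copy;      // inplace-by-address with the copy as input
  bool allocation_check_on_copy;   // inplace-by-allocation with the copy as input
  bool address_check_on_original;  // inplace-by-address with x itself as input
};

DemoResult RunDemo() {
  ToyTensor x{{2, 3}, std::make_shared<std::vector<float>>(6, 1.0f)};
  ToyTensor* kernel_out = &x;    // inplace: the output aliases x
  ToyTensor origin_input_x = x;  // shallow copy taken before InferMeta
  ToyReshapeInferMeta(origin_input_x, {3, 2}, kernel_out);

  const std::vector<int64_t> old_dims{2, 3};
  const std::vector<int64_t> new_dims{3, 2};
  return DemoResult{origin_input_x.dims == old_dims,
                    x.dims == new_dims,
                    IsInplaceByAddress(origin_input_x, kernel_out),
                    IsInplaceByAllocation(origin_input_x, kernel_out),
                    IsInplaceByAddress(x, kernel_out)};
}
```

In this toy, passing x itself restores the address-based check, but the kernel would then see the meta that InferMeta already overwrote, whereas the copy still carries the original dims. Whether real phi kernels compare addresses or allocations, and whether they rely on the pre-InferMeta input meta, is exactly the open question here.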
