Please ask your question
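In #67009, a temporary shallow copy of the input DenseTensor is made before infermeta. This keeps the infer-meta stage from mistakenly modifying the input's meta information. The code example given there is: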
Tensor& api_output = x; // <---- under the inplace mechanism, out reuses x directly, so infermeta would overwrite x's meta
auto kernel_out = SetKernelOutput(&api_output);
phi::RecordEvent *infer_shape_record_event = nullptr;
if(phi::RecordEvent::IsEnabled()){
infer_shape_record_event = new phi::RecordEvent("reshape infer_meta", phi::TracerEventType::OperatorInner, 1);
}
auto origin_input_x = *input_x; // <----- shallow copy of the input Tensor here
phi::MetaTensor meta_out(kernel_out, kernel_result.is_stride_kernel);
phi::ReshapeInferMeta(MakeMetaTensor(origin_input_x), shape, &meta_out); // <--- run infermeta on origin_input_x
if(infer_shape_record_event != nullptr){
delete infer_shape_record_event;
}
using kernel_signature = void(*)(const phi::DeviceContext&, const phi::DenseTensor&, const phi::IntArray&, phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
phi::RecordEvent* kernel_record_event = nullptr;
if(phi::RecordEvent::IsEnabled()){
kernel_record_event = new phi::RecordEvent("reshape compute", phi::TracerEventType::OperatorInner, 1);
}
(*kernel_fn)(*dev_ctx, origin_input_x, phi::IntArray(shape), kernel_out);
if(kernel_record_event != nullptr){
delete kernel_record_event;
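}
However, there seems to be a problem. The kernel call `(*kernel_fn)(*dev_ctx, origin_input_x, phi::IntArray(shape), kernel_out);` passes the shallow-copied tensor `origin_input_x` as input, but `kernel_out` as output, and `kernel_out` points to the original tensor. So the input and output are not actually the same tensor object, although they share the same allocation. This causes trouble for some inplace checks inside kernels. Should the kernel therefore be called with the original tensor as input instead, i.e. `(*kernel_fn)(*dev_ctx, *input_x, phi::IntArray(shape), kernel_out);`? @Aurelius84, could you please help answer this? Thanks!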