Skip to content

Commit 8960332

Browse files
chen2016013 authored and lixcli committed
[PIR] Refactor Inplace strategy (PaddlePaddle#65491)
* Refactor Inplace
* update
* handle for tensorarray
* update
* fix assign_value_
* update
* update
* update
* fix custom meta bug
* fix optional value bug
1 parent 1cf0e34 commit 8960332

File tree

12 files changed

+185
-18
lines changed

12 files changed

+185
-18
lines changed

paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,18 @@ CustomKernelInstruction::CustomKernelInstruction(
410410
GetStreamPriority()));
411411
VLOG(6) << "finish process device context";
412412

413+
auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_);
414+
for (auto const& pair : op_inplace_map) {
415+
pir::Value input_value =
416+
op->operand_source(yaml_info_parser.InputName2Id().at(pair.first));
417+
pir::Value output_value =
418+
op->result(yaml_info_parser.OutputName2Id().at(pair.second));
419+
if (IsInvalid(output_value) && IsInvalid(input_value)) {
420+
this->AddInplace(value_exec_info_.GetVarByValue(input_value),
421+
value_exec_info_.GetVarByValue(output_value));
422+
}
423+
}
424+
413425
InitInputsOutputsIds(op, value_exec_info_);
414426
VLOG(6) << "finish process inputs outputs index";
415427

@@ -453,6 +465,7 @@ void CustomKernelInstruction::UpdateOutputMeta(
453465
auto out_meta = phi::DenseTensorUtils::GetMutableMeta(out_in_scope);
454466
out_meta->dims = phi::make_ddim(output_shapes[i]);
455467
out_meta->dtype = output_dtypes[i];
468+
out_meta->strides = out_meta->calc_strides(out_meta->dims);
456469
}
457470
}
458471

@@ -504,7 +517,9 @@ void CustomKernelInstruction::Run() {
504517
vec_input_name2id_map_,
505518
custom_attrs_);
506519
UpdateOutputMeta(output_shapes, output_dtypes);
507-
520+
for (auto& pair : this->InplaceInfo()) {
521+
ShareVarBuffer(pair.first, pair.second);
522+
}
508523
VLOG(6) << "Run custom op " << custom_op_name_ << " kernel.";
509524
kernel_func_(&custom_kernel_ctx_);
510525
}

paddle/fluid/framework/new_executor/instruction/instruction_base.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,12 @@ const std::vector<Variable*>& InstructionBase::EagerGCVars() const {
273273

274274
void InstructionBase::ClearEagerGCVars() { eager_gc_vars_.clear(); }
275275

276-
const std::vector<std::pair<Variable*, Variable*>>&
276+
const std::vector<std::pair<const Variable*, Variable*>>&
277277
InstructionBase::InplaceInfo() const {
278278
return vec_inplace_in_to_out_;
279279
}
280280

281-
void InstructionBase::AddInplace(Variable* in, Variable* out) {
281+
void InstructionBase::AddInplace(const Variable* in, Variable* out) {
282282
vec_inplace_in_to_out_.emplace_back(in, out);
283283
}
284284

@@ -334,6 +334,17 @@ void InstructionBase::InitInputsOutputsIds(
334334
outputs.emplace(value, outputs_id);
335335
}
336336
}
337+
338+
const auto value_2_var_name_map = value_exec_info.GetValue2VarName();
339+
for (auto inplace_var_pair : this->InplaceInfo()) {
340+
for (auto item : value_2_var_name_map) {
341+
if (item.second == value_exec_info.GetVarName(inplace_var_pair.first)) {
342+
std::vector<int> outputs_id = GetValueIds(item.first, value_exec_info);
343+
outputs.emplace(item.first, outputs_id);
344+
break;
345+
}
346+
}
347+
}
337348
SetOutputs(outputs);
338349
VLOG(8) << "finish process outputs_index";
339350
}

paddle/fluid/framework/new_executor/instruction/instruction_base.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ class InstructionBase {
127127
void AddEagerGCVar(Variable* var);
128128
void ClearEagerGCVars();
129129

130-
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
131-
void AddInplace(Variable* in, Variable* out);
130+
const std::vector<std::pair<const Variable*, Variable*>>& InplaceInfo() const;
131+
void AddInplace(const Variable* in, Variable* out);
132132
void ClearInplace();
133133

134134
std::map<int, int>& GetMutableInplaceBackMap() { return inplace_back_map_; }
@@ -207,7 +207,7 @@ class InstructionBase {
207207

208208
std::vector<Variable*> eager_gc_vars_;
209209

210-
std::vector<std::pair<Variable*, Variable*>>
210+
std::vector<std::pair<const Variable*, Variable*>>
211211
vec_inplace_in_to_out_; // If not use share data, need this ?
212212

213213
std::map<int, int> inplace_back_map_;

paddle/fluid/framework/new_executor/instruction/instruction_util.cc

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,4 +406,106 @@ bool GetCondData(const phi::DenseTensor& cond) {
406406
return cpu_cond->data<bool>()[0];
407407
}
408408

409+
// NOTE(chenxi67): Here, we only perform inplace processing for variables whose
410+
// type is NOT TensorArray. It has already been processed in the previous
411+
// step(HandleForInplaceVarOp).
412+
void HandleForInplaceOp(pir::Operation* op,
413+
const ValueExecutionInfo* value_exe_info,
414+
InstructionBase* instr) {
415+
if (op->num_results() < 1) return;
416+
pir::IrContext* ctx = pir::IrContext::Instance();
417+
std::string op_name = op->name();
418+
if (op->attributes().count("op_name")) {
419+
op_name =
420+
op->attributes().at("op_name").dyn_cast<pir::StrAttribute>().AsString();
421+
}
422+
423+
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
424+
paddle::dialect::OpYamlInfoParser yaml_parser(
425+
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
426+
->get_op_info_(op_name),
427+
paddle::dialect::IsLegacyOp(op_name));
428+
429+
for (size_t i = 0; i < op->num_results(); ++i) {
430+
pir::Value value = op->result(i);
431+
if (!IsInvalid(value)) {
432+
VLOG(8) << "Number " << i << " result of " << op_name
433+
<< " is not invalid, so skip build a variable.";
434+
continue;
435+
}
436+
if (IsNeedVarInplace(op, value, op_name)) {
437+
continue;
438+
}
439+
std::string value_name = yaml_parser.OutputNames()[i];
440+
if (yaml_parser.HasInplace(value_name)) {
441+
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
442+
pir::Value inplace_value =
443+
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
444+
std::string input_var_name = value_exe_info->GetVarName(inplace_value);
445+
std::string output_var_name = value_exe_info->GetVarName(value);
446+
PADDLE_ENFORCE_NE(input_var_name,
447+
"",
448+
phi::errors::InvalidArgument(
449+
"The input var name of inplace op is empty."));
450+
PADDLE_ENFORCE_NE(output_var_name,
451+
"",
452+
phi::errors::InvalidArgument(
453+
"The output var name of inplace op is empty."));
454+
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
455+
<< " (var: " << input_var_name << ")";
456+
instr->AddInplace(value_exe_info->GetVarByValue(inplace_value),
457+
value_exe_info->GetVarByValue(value));
458+
} else if (yaml_parser.HasView(value_name)) {
459+
const std::string& view_name = yaml_parser.ViewName(value_name);
460+
pir::Value view_value =
461+
op->operand_source(yaml_parser.InputName2Id().at(view_name));
462+
// const std::string& var_name = value_2_var_name->at(view_value);
463+
std::string input_var_name = value_exe_info->GetVarName(view_value);
464+
std::string output_var_name = value_exe_info->GetVarName(value);
465+
466+
PADDLE_ENFORCE_NE(input_var_name,
467+
"",
468+
platform::errors::InvalidArgument(
469+
"The input var name of view op is empty."));
470+
PADDLE_ENFORCE_NE(output_var_name,
471+
"",
472+
platform::errors::InvalidArgument(
473+
"The output var name of view op is empty."));
474+
VLOG(4) << "view: " << value_name << " -> " << view_name
475+
<< " (var: " << input_var_name << ")";
476+
instr->AddInplace(value_exe_info->GetVarByValue(view_value),
477+
value_exe_info->GetVarByValue(value));
478+
}
479+
}
480+
}
481+
482+
void ShareVarBuffer(const Variable* src_var, Variable* dst_var) {
483+
if (src_var->IsType<phi::DenseTensor>()) {
484+
auto& src_tensor = src_var->Get<phi::DenseTensor>();
485+
auto* tmp_dst_tensor = dst_var->GetMutable<phi::DenseTensor>();
486+
tmp_dst_tensor->ShareBufferWith(src_tensor);
487+
return;
488+
} else if (src_var->IsType<phi::SelectedRows>()) {
489+
auto* tmp_dst_slr = dst_var->GetMutable<phi::SelectedRows>();
490+
auto* dst_t = tmp_dst_slr->mutable_value();
491+
auto& src_slr = src_var->Get<phi::SelectedRows>();
492+
auto& src_t = src_slr.value();
493+
dst_t->ShareBufferWith(src_t);
494+
return;
495+
} else if (src_var->IsType<VariableRefArray>()) {
496+
auto src_var_array = src_var->Get<VariableRefArray>();
497+
auto* dst_var_array = dst_var->GetMutable<VariableRefArray>();
498+
for (size_t i = 0; i < src_var_array.size(); ++i) {
499+
Variable* copy_var = const_cast<Variable*>(dst_var_array->at(i));
500+
ShareVarBuffer(src_var_array.at(i), copy_var);
501+
}
502+
return;
503+
} else {
504+
PADDLE_THROW(phi::errors::PreconditionNotMet(
505+
"Output only support DenseTensorType "
506+
"or SelectedRowsType or VariableRefArray"));
507+
}
508+
return;
509+
}
510+
409511
} // namespace paddle::framework

paddle/fluid/framework/new_executor/instruction/instruction_util.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,10 @@ void InsertInplacedExternalInputsToOuts(
6464

6565
bool GetCondData(const phi::DenseTensor& cond);
6666

67+
void HandleForInplaceOp(pir::Operation* op,
68+
const ValueExecutionInfo* value_exe_info,
69+
InstructionBase* instr);
70+
71+
void ShareVarBuffer(const Variable* src_var, Variable* dst_var);
6772
} // namespace framework
6873
} // namespace paddle

paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ LegacyKernelInstruction::LegacyKernelInstruction(
163163

164164
VLOG(6) << "finish process kernel context";
165165

166+
if (op->attributes().count("is_inplace") != 0 &&
167+
op->attributes().at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
168+
HandleForInplaceOp(op, value_exec_info_, this);
169+
}
170+
166171
InitInputsOutputsIds(op, *value_exec_info);
167172
VLOG(6) << "finish process inputs outputs index";
168173

@@ -185,6 +190,9 @@ void LegacyKernelInstruction::Run() {
185190
if (infer_meta_interface_) {
186191
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
187192
}
193+
for (auto& pair : this->InplaceInfo()) {
194+
ShareVarBuffer(pair.first, pair.second);
195+
}
188196
VLOG(6) << "Run op " << legacy_op_name_ << " kernel.";
189197
(*(phi_kernel_))((kernel_context_));
190198
}

paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ PhiKernelInstruction::PhiKernelInstruction(
160160

161161
kernel_context_.SetDeviceContext(dev_ctx);
162162
VLOG(6) << "finish process kernel context";
163-
163+
if (op->attributes().count("is_inplace") != 0 &&
164+
op->attributes().at("is_inplace").dyn_cast<pir::BoolAttribute>().data()) {
165+
HandleForInplaceOp(op, value_exec_info_, this);
166+
}
164167
InitInputsOutputsIds(op, *value_exec_info);
165168
VLOG(6) << "finish process inputs outputs index";
166169

@@ -181,6 +184,9 @@ void PhiKernelInstruction::Run() {
181184
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
182185
}
183186
VLOG(6) << "End run op " << phi_op_name_ << " infer meta.";
187+
for (auto& pair : this->InplaceInfo()) {
188+
ShareVarBuffer(pair.first, pair.second);
189+
}
184190
VLOG(6) << "Begin run op " << phi_op_name_ << " kernel.";
185191
(*(phi_kernel_))(&(kernel_context_));
186192
VLOG(6) << "End run op " << phi_op_name_ << " kernel.";

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,12 @@ const platform::DeviceContext& Instruction::DeviceContext() const {
306306
return dev_ctx_;
307307
}
308308

309-
const std::vector<std::pair<Variable*, Variable*>>& Instruction::InplaceInfo()
310-
const {
309+
const std::vector<std::pair<const Variable*, Variable*>>&
310+
Instruction::InplaceInfo() const {
311311
return vec_inplace_in_to_out_;
312312
}
313313

314-
void Instruction::AddInplace(Variable* in, Variable* out) {
314+
void Instruction::AddInplace(const Variable* in, Variable* out) {
315315
vec_inplace_in_to_out_.emplace_back(in, out);
316316
}
317317

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ class Instruction {
295295

296296
const platform::DeviceContext& DeviceContext() const;
297297

298-
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
298+
const std::vector<std::pair<const Variable*, Variable*>>& InplaceInfo() const;
299299

300-
void AddInplace(Variable* in, Variable* out);
300+
void AddInplace(const Variable* in, Variable* out);
301301

302302
void ClearInplace();
303303

@@ -340,7 +340,7 @@ class Instruction {
340340

341341
std::vector<size_t> gc_check_vars_;
342342

343-
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
343+
std::vector<std::pair<const Variable*, Variable*>> vec_inplace_in_to_out_;
344344

345345
bool pre_define_context_{false};
346346
};

paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,21 @@ void HandleForSpecialOp(pir::Operation* op,
682682
}
683683
}
684684

685-
void HandleForInplaceOp(pir::Operation* op,
686-
const std::string& var_name_prefix,
687-
ValueExecutionInfo* value_exe_info) {
685+
// Returns true when result `value` of an inplace op must be handled at the
// Variable level (by HandleForInplaceVarOp) instead of by buffer sharing:
// TensorArray-typed outputs, and pd_op.assign_value_ (whose output tensor
// may be re-allocated, so sharing the holder alone is not enough).
bool IsNeedVarInplace(pir::Operation* op,
                      pir::Value value,
                      std::string op_name) {
  (void)op;  // currently unused; kept so the declared interface stays stable
  return value.type().isa<paddle::dialect::DenseTensorArrayType>() ||
         op_name == "pd_op.assign_value_";
}
691+
692+
// NOTE(chenxi67): Here, we only perform inplace processing for variables that
693+
// need to be inplaced by var (mostly, whose type is TensorArray or re-Allocated
694+
// Densetensor). For other types of variables, we only share the holder of
695+
// DenseTensor but not the var*. The reason is that vector<DenseTensor> in
696+
// TensorArray (or re-Allocated Densetensor) cannot be shared totally.
697+
void HandleForInplaceVarOp(pir::Operation* op,
698+
const std::string& var_name_prefix,
699+
ValueExecutionInfo* value_exe_info) {
688700
if (op->num_results() < 1) return;
689701
pir::IrContext* ctx = pir::IrContext::Instance();
690702
std::string op_name = op->name();
@@ -706,6 +718,10 @@ void HandleForInplaceOp(pir::Operation* op,
706718
<< " is not invalid, so skip build a variable.";
707719
continue;
708720
}
721+
if (!IsNeedVarInplace(op, value, op_name)) {
722+
BuildValue(value, var_name_prefix, value_exe_info);
723+
continue;
724+
}
709725
std::string value_name = yaml_parser.OutputNames()[i];
710726
if (yaml_parser.HasInplace(value_name)) {
711727
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
@@ -785,7 +801,7 @@ void BuildScope(const pir::Block& block,
785801
.at("is_inplace")
786802
.dyn_cast<pir::BoolAttribute>()
787803
.data()) {
788-
HandleForInplaceOp(&op, var_name_prefix, value_exe_info);
804+
HandleForInplaceVarOp(&op, var_name_prefix, value_exe_info);
789805
continue;
790806
} else {
791807
for (size_t i = 0; i < op.num_results(); ++i) {

0 commit comments

Comments (0)