Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/external/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,3 @@ add_library(cinn SHARED IMPORTED GLOBAL)
set_target_properties(cinn PROPERTIES IMPORTED_LOCATION "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}")
include_directories(${CINN_INCLUDE_DIR})
add_dependencies(cinn external_cinn)

11 changes: 10 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,17 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}

// set runtime_ctx and infershape_ctx_
instr_node->ResetContext(ins_map, outs_map);
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
}

void InterpreterCore::BuildSkipShareLoDInfo() {
Expand Down
17 changes: 14 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,19 @@ void build_op_func_list(const platform::Place& place,
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
Scope scope;
Scope* runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., cinn_launch OP). Here special
// treatment for them.
if (op_with_kernel->Type() == "cinn_launch") {
VLOG(6) << "OP(" << op_with_kernel->Type() << ") use scope in kernel, "
"so pass a real scope to "
"ExecutionContext";
runtime_scope = local_scope;
}

auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context));
op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

// change device by the device_guard()
Expand Down Expand Up @@ -439,8 +450,8 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}

auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope,
*dev_ctx, runtime_context);

auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,16 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
}

void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
}

std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
return runtime_ctx_;
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ class Instruction {
void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars);

void ResetContextWithScope(const VariableValueMap& in_vars,
const VariableValueMap& out_vars,
const framework::Scope& scope);

std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;

std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
Expand Down