Commit 79203ec

wanghuancoder authored and piotrekobi committed
fix some bug, test=develop (PaddlePaddle#36888)
1 parent cb9de59 commit 79203ec

File tree: 4 files changed, +57 -10 lines

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 4 additions & 3 deletions
@@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() {
     auto& outputs = instr.Outputs();
     for (auto& pair : in_to_outs) {
       auto iter = inputs.find(pair.first);
-      if (iter != inputs.end()) {
+      if (iter != inputs.end() && !iter->second.empty()) {
         if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
           auto iterout = outputs.find(pair.second);
-          if (iterout != outputs.end()) {
+          if (iterout != outputs.end() && !iterout->second.empty()) {
             auto invar = global_scope_->Var(iter->second[0]);
             auto outvar = global_scope_->Var(iterout->second[0]);
-            if (invar && outvar) {
+            if (invar && outvar && invar->IsType<LoDTensor>() &&
+                outvar->IsType<LoDTensor>()) {
               instr.AddInplace(invar, outvar);
               VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
                       << " " << global_scope_->GetNameById(iter->second[0])

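Note on the interpretercore.cc hunk: the new checks guard against an input or output name that maps to an empty id list, and against non-LoDTensor variables, before an inplace pair is registered. Below is a minimal standalone sketch of that guard pattern; the names (VarTable, FirstIdOrMinusOne) are hypothetical and not Paddle's actual classes.

// Never index element [0] of a mapped vector without first confirming the key
// exists and the vector is non-empty -- the pattern the hunk above adds.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical name -> variable-id table, loosely mirroring instr.Inputs()/Outputs().
using VarTable = std::unordered_map<std::string, std::vector<size_t>>;

// Returns the first id mapped to `name`, or -1 if the entry is absent or empty.
long FirstIdOrMinusOne(const VarTable& table, const std::string& name) {
  auto iter = table.find(name);
  if (iter != table.end() && !iter->second.empty()) {  // the added empty() check
    return static_cast<long>(iter->second[0]);
  }
  return -1;
}

int main() {
  VarTable inputs{{"X", {7}}, {"Y", {}}};  // "Y" exists but maps to no ids
  std::cout << FirstIdOrMinusOne(inputs, "X") << "\n";  // 7
  std::cout << FirstIdOrMinusOne(inputs, "Y") << "\n";  // -1, out-of-bounds without the check
  std::cout << FirstIdOrMinusOne(inputs, "Z") << "\n";  // -1, key not present
}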
paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 18 additions & 5 deletions
@@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
     if (nullptr == var_scope->FindVar(var_name)) {
       var_scope->AddVar(var_desc->Name(), var_desc);
     } else {
-      auto* var_desc = var_scope->VarDesc(var_name);
-      if (nullptr == var_desc) {
+      auto* var_desc_tmp = var_scope->VarDesc(var_name);
+      if (nullptr == var_desc_tmp) {
         VLOG(3) << "update var:" << var_name << " desc from nullptr into "
                 << var_desc;
         var_scope->VarMetaInfo(var_name).vardesc_ = var_desc;

@@ -206,9 +206,22 @@ void apply_device_guard(const OperatorBase* op_base,
     VLOG(3) << "Switch into CPUPlace by device_guard.";
     expected_kernel_key->place_ = platform::CPUPlace();
   } else if (op_device.find("gpu") != std::string::npos &&
-             platform::is_gpu_place(place)) {
-    VLOG(3) << "Switch into " << place << " by device_guard.";
-    expected_kernel_key->place_ = place;
+             (platform::is_gpu_place(place) ||
+              platform::is_npu_place(place))) {
+    // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
+    // will be executed and a warning will be given at the same time.
+    if (op_base->SupportGPU()) {
+      expected_kernel_key->place_ = place;
+    } else if (op_base->SupportNPU()) {
+      expected_kernel_key->place_ = place;
+    } else {
+      expected_kernel_key->place_ = platform::CPUPlace();
+      LOG_FIRST_N(WARNING, 1)
+          << "Op(" << op_base->Type()
+          << ") has no CUDA implementation. It will be assigned to CPUPlace.";
+    }
+    VLOG(3) << "Switch into " << expected_kernel_key->place_
+            << " by device_guard.";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Unsupported current place %s", op_device));

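Note on the apply_device_guard hunk: an op pinned to "gpu" by device_guard now keeps the requested GPU/NPU place only if the op actually supports that device; otherwise it falls back to CPUPlace with a one-time warning. A simplified sketch of that decision follows; OpInfo and ResolveGuardedPlace are hypothetical stand-ins, not Paddle's OperatorBase API.

// Choose the place an op should run on when device_guard requests "gpu".
#include <iostream>
#include <string>

struct OpInfo {
  std::string type;
  bool supports_gpu;
  bool supports_npu;
};

std::string ResolveGuardedPlace(const OpInfo& op, const std::string& requested) {
  if (op.supports_gpu || op.supports_npu) {
    return requested;  // keep the GPU/NPU place requested by device_guard
  }
  // CPU-only kernel: warn and fall back, mirroring the LOG_FIRST_N(WARNING, 1) path.
  std::cerr << "Op(" << op.type
            << ") has no CUDA implementation. It will be assigned to CPUPlace.\n";
  return "CPUPlace";
}

int main() {
  std::cout << ResolveGuardedPlace({"matmul", true, false}, "CUDAPlace(0)") << "\n";
  std::cout << ResolveGuardedPlace({"cpu_only_op", false, false}, "CUDAPlace(0)") << "\n";
}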
paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 9 additions & 0 deletions
@@ -474,6 +474,15 @@ struct VariableMetaInfo {
 // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
 class VariableScope : public ScopeBase {
  public:
+  VariableScope() {
+    // for @EMPTY@ variable
+    var_list_.push_back(nullptr);
+    name2id_[kEmptyVarName] = 0;
+    VariableMetaInfo info;
+    info.var_ref_count_ = 0;
+    info.vardesc_ = nullptr;
+    vec_meta_info_.push_back(info);
+  }
   Variable* FindVar(const std::string& name) const {
     auto it = name2id_.find(name);
     if (it != name2id_.end()) {

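Note on the new VariableScope constructor: slot 0 of var_list_ is reserved for the framework's empty variable name, so looking up @EMPTY@ resolves to a null entry instead of colliding with a real variable id. A self-contained sketch of the same idea, with simplified hypothetical types (TinyVariableScope, Variable) in place of the real classes:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Sentinel name whose slot 0 is pre-reserved, mirroring kEmptyVarName.
constexpr char kEmptyVarName[] = "@EMPTY@";

struct Variable {};  // stand-in for framework::Variable

class TinyVariableScope {
 public:
  TinyVariableScope() {
    // Reserve id 0 for the empty variable, as the new VariableScope ctor does.
    var_list_.push_back(nullptr);
    name2id_[kEmptyVarName] = 0;
  }

  size_t AddVar(const std::string& name) {
    var_list_.push_back(std::make_unique<Variable>());
    name2id_[name] = var_list_.size() - 1;
    return name2id_[name];
  }

  Variable* FindVar(const std::string& name) const {
    auto it = name2id_.find(name);
    return it == name2id_.end() ? nullptr : var_list_[it->second].get();
  }

 private:
  std::vector<std::unique_ptr<Variable>> var_list_;
  std::unordered_map<std::string, size_t> name2id_;
};

int main() {
  TinyVariableScope scope;
  // The reserved name resolves to id 0, which holds no real variable.
  assert(scope.FindVar(kEmptyVarName) == nullptr);
  size_t id = scope.AddVar("x");
  assert(id == 1 && scope.FindVar("x") != nullptr);
  return 0;
}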
paddle/fluid/operators/controlflow/fetch_v2_op.cc

Lines changed: 26 additions & 2 deletions
@@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
+    if (!tensor.IsInitialized()) {
+      return expected_kernel_type;
+    }
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
+    auto *fetch_var = ctx.InputVar("X");
+    if (fetch_var == nullptr) {
+      return framework::OpKernelType(framework::proto::VarType::FP32,
+                                     platform::CPUPlace());
+    }
+
+    if (fetch_var->IsType<framework::LoDTensor>()) {
+      auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    } else {
+      auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
+      if (src_item.empty() || !src_item[0].IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    }
+
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
         platform::CPUPlace());

@@ -127,6 +150,9 @@ class FetchV2Kernel {
 
     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return;
+      }
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
       bool check_place = platform::is_cpu_place(src_item.place()) ||
                          platform::is_cuda_pinned_place(src_item.place());

@@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.
-
 It should not be configured by users directly.
-
 )DOC");
   }
 };

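Note on the fetch_v2_op.cc hunks: when the fetched variable is missing or its tensor is not yet initialized, FetchV2 now falls back to a default FP32/CPUPlace kernel key (and the kernel simply returns) instead of querying an uninitialized tensor. An illustrative sketch of that fallback, using hypothetical TensorStub and KernelKey structs rather than Paddle's OpKernelType:

#include <iostream>
#include <string>

struct TensorStub {
  bool initialized = false;
  std::string dtype = "undefined";
  std::string place = "undefined";
};

struct KernelKey {
  std::string dtype;
  std::string place;
};

// Mirrors the GetExpectedKernelType fallback: a default key unless the tensor is usable.
KernelKey ExpectedFetchKernelKey(const TensorStub* fetch_var) {
  if (fetch_var == nullptr || !fetch_var->initialized) {
    return {"FP32", "CPUPlace"};  // safe default, as in the diff's fallback branch
  }
  return {fetch_var->dtype, "CPUPlace"};  // fetch always lands on the CPU
}

int main() {
  TensorStub t{true, "FP64", "CUDAPlace(0)"};
  KernelKey k1 = ExpectedFetchKernelKey(&t);
  KernelKey k2 = ExpectedFetchKernelKey(nullptr);
  std::cout << k1.dtype << "/" << k1.place << "\n";  // FP64/CPUPlace
  std::cout << k2.dtype << "/" << k2.place << "\n";  // FP32/CPUPlace
}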