Skip to content

Commit d68c8c5

Browse files
committed
update pass
1 parent 27308d3 commit d68c8c5

File tree

1 file changed

+36
-13
lines changed

1 file changed

+36
-13
lines changed

paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ void QuantDequantMkldnnPass::CollectInfoFromFake(
9393
auto scale_name = op_desc->Input("Scales")[0];
9494
auto* var = scope->FindVar(scale_name);
9595
PADDLE_ENFORCE_NOT_NULL(
96-
var, "The Scales variable of dequantize op is not found.");
96+
var, platform::errors::NotFound(
97+
"The Scales variable [%s] of dequantize op is not found.",
98+
var));
9799

98100
auto* scale_tensor = var->GetMutable<LoDTensor>();
99101
auto* scale_data = scale_tensor->data<float>();
@@ -132,7 +134,9 @@ void QuantDequantMkldnnPass::CollectInputScalesFromFake(
132134
auto out_var_name = op_desc->Output("Out")[0];
133135
auto* var = scope->FindVar(scale_name);
134136
PADDLE_ENFORCE_NOT_NULL(
135-
var, "The InScale variable of quantize op is not found.");
137+
var,
138+
platform::errors::NotFound(
139+
"The InScale variable [%s] of quantize op is not found.", var));
136140

137141
auto* scale_tensor = var->GetMutable<LoDTensor>();
138142
auto* scale_data = scale_tensor->data<float>();
@@ -196,8 +200,10 @@ void QuantDequantMkldnnPass::CollectFakeQuantizeOps(
196200
for (auto* node_input : op_node->inputs) {
197201
if (node_input->Name() == x_var_name) {
198202
fake_quant_in = node_input;
203+
break;
199204
} else if (node_input->Name() == in_scale_name) {
200205
fake_quant_in_scale = node_input;
206+
break;
201207
}
202208
}
203209

@@ -206,15 +212,22 @@ void QuantDequantMkldnnPass::CollectFakeQuantizeOps(
206212
for (auto* node_output : op_node->outputs) {
207213
if (node_output->Name() == out_var_name) {
208214
fake_quant_out = node_output;
215+
break;
209216
} else if (node_output->Name() == out_scale_name) {
210217
fake_quant_out_scale = node_output;
218+
break;
211219
}
212220
}
213221

214-
PADDLE_ENFORCE_NOT_NULL(fake_quant_in,
215-
"The input var of quantize op is not found.");
216-
PADDLE_ENFORCE_NOT_NULL(fake_quant_out,
217-
"The output var of quantize op is not found.");
222+
PADDLE_ENFORCE_NOT_NULL(
223+
fake_quant_in,
224+
platform::errors::NotFound(
225+
"The input var [%s] of quantize op is not found.", x_var_name));
226+
PADDLE_ENFORCE_NOT_NULL(
227+
fake_quant_out,
228+
platform::errors::NotFound(
229+
"The output var [%s] of quantize op is not found.", out_var_name));
230+
218231
std::string input_act_name = fake_quant_in->Var()->Name();
219232
std::string output_act_name = fake_quant_out->Var()->Name();
220233
auto outlinks = fake_quant_out->outputs;
@@ -241,20 +254,27 @@ void QuantDequantMkldnnPass::CollectFakeDequantizeOps(
241254
for (auto* node_input : op_node->inputs) {
242255
if (node_input->Name() == x_var_name) {
243256
fake_dequant_in = node_input;
257+
break;
244258
}
245259
}
246260

247261
Node* fake_dequant_out = nullptr;
248262
for (auto* node_output : op_node->outputs) {
249263
if (node_output->Name() == out_var_name) {
250264
fake_dequant_out = node_output;
265+
break;
251266
}
252267
}
253268

254-
PADDLE_ENFORCE_NOT_NULL(fake_dequant_in,
255-
"The input var of dequantize op is not found.");
256-
PADDLE_ENFORCE_NOT_NULL(fake_dequant_out,
257-
"The output var of dequantize op is not found.");
269+
PADDLE_ENFORCE_NOT_NULL(
270+
fake_dequant_in,
271+
platform::errors::NotFound(
272+
"The input var [%s] of dequantize op is not found.", x_var_name));
273+
PADDLE_ENFORCE_NOT_NULL(
274+
fake_dequant_out,
275+
platform::errors::NotFound(
276+
"The output var [%s] of dequantize op is not found.", out_var_name));
277+
258278
std::string input_act_name = fake_dequant_in->Var()->Name();
259279
std::string output_act_name = fake_dequant_out->Var()->Name();
260280
auto outlinks = fake_dequant_out->outputs;
@@ -335,8 +355,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
335355
auto var_name = op_desc->Input(weight_name)[0];
336356
auto* var = scope->FindVar(var_name);
337357
PADDLE_ENFORCE_NOT_NULL(
338-
var, "The input persistable var of %s op is not found.", op_desc->Type());
339-
358+
var, platform::errors::NotFound(
359+
"The input persistable [%s] var of [%s] op is not found.",
360+
var_name, op_desc->Type()));
340361
auto* weight_tensor = var->GetMutable<LoDTensor>();
341362
auto* weight_data = weight_tensor->data<float>();
342363
bool is_int8 = true;
@@ -371,7 +392,9 @@ void QuantDequantMkldnnPass::DequantizeOpWeights(
371392

372393
auto* var = scope->FindVar(weight_var_name);
373394
PADDLE_ENFORCE_NOT_NULL(
374-
var, "The input persistable var of %s op is not found.", op_desc->Type());
395+
var, platform::errors::NotFound(
396+
"The input persistable [%s] var of [%s] op is not found.",
397+
weight_var_name, op_desc->Type()));
375398
auto* weight_tensor = var->GetMutable<LoDTensor>();
376399
const auto weight_dims = weight_tensor->dims();
377400

0 commit comments

Comments
 (0)