@@ -1053,6 +1053,67 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
10531053 PrettyLogDetail (" --- quantized %d fusion_lstm ops" , quantize_count);
10541054}
10551055
1056+ void CPUQuantizePass::QuantizeNearestInterp (Graph* graph) const {
1057+ GraphPatternDetector gpd;
1058+ auto pattern = gpd.mutable_pattern ();
1059+ patterns::NearestInterp nearest_interp_pattern{pattern, name_scope_};
1060+ nearest_interp_pattern ();
1061+
1062+ int quantize_nearest_interp_count = 0 ;
1063+ auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
1064+ Graph* g) {
1065+ VLOG (4 ) << " Quantize nearest_interp op" ;
1066+ GET_IR_NODE_FROM_SUBGRAPH (nearest_interp_op, nearest_interp_op,
1067+ nearest_interp_pattern);
1068+
1069+ // skip if should not be quantized
1070+ if (!platform::HasOpINT8DataType (nearest_interp_op->Op ())) {
1071+ LogQuantizationDisabled (nearest_interp_op);
1072+ return ;
1073+ }
1074+ GET_IR_NODE_FROM_SUBGRAPH (prev_op, prev_op, nearest_interp_pattern);
1075+ GET_IR_NODE_FROM_SUBGRAPH (next_op, next_op, nearest_interp_pattern);
1076+
1077+ // skip if prev op and next op is not quantized
1078+ if (!(IsOpDequantized (prev_op)) && !(IsOpQuantized (next_op))) {
1079+ LogCannotQuantizeOp (nearest_interp_op,
1080+ " There are no other quantized operators nearby, so "
1081+ " quantization is not recommended." );
1082+ return ;
1083+ }
1084+
1085+ GET_IR_NODE_FROM_SUBGRAPH (nearest_interp_in, nearest_interp_in,
1086+ nearest_interp_pattern);
1087+ GET_IR_NODE_FROM_SUBGRAPH (nearest_interp_out, nearest_interp_out,
1088+ nearest_interp_pattern);
1089+
1090+ if (!AreScalesPresentForNodes ({nearest_interp_in, nearest_interp_out})) {
1091+ LogCannotQuantizeOp (nearest_interp_op);
1092+ return ;
1093+ }
1094+
1095+ bool is_input_unsigned{false };
1096+ auto input_scale =
1097+ GetScaleValueForNode (nearest_interp_in, &is_input_unsigned);
1098+ QuantizeInput (g, nearest_interp_op, nearest_interp_in, " X" , input_scale,
1099+ is_input_unsigned);
1100+
1101+ bool is_output_unsigned{false };
1102+ auto output_scale =
1103+ GetScaleValueForNode (nearest_interp_out, &is_output_unsigned);
1104+ DequantizeOutput (g, nearest_interp_op, nearest_interp_out, " Out" ,
1105+ output_scale, is_output_unsigned);
1106+
1107+ ++quantize_nearest_interp_count;
1108+ };
1109+
1110+ gpd (graph, handler);
1111+ AddStatis (quantize_nearest_interp_count);
1112+
1113+ PrettyLogDetail (" --- quantized %d nearest_interp ops" ,
1114+ quantize_nearest_interp_count);
1115+ }
1116+
10561117void CPUQuantizePass::ApplyImpl (ir::Graph* graph) const {
10571118 VLOG (3 ) << " Quantizing the graph." ;
10581119 PADDLE_ENFORCE_NOT_NULL (
@@ -1076,6 +1137,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
10761137 QuantizeMultiGru (graph);
10771138 QuantizeFusionLSTM (graph);
10781139 QuantizeSlice (graph);
1140+ QuantizeNearestInterp (graph);
10791141}
10801142
10811143} // namespace ir
0 commit comments