Skip to content

Commit e68da18

Browse files
authored
add mkldnn int8 pass [step1] (#41579)
* add mkldnn int8 pass * add mkldnn int8 pass * update pass
1 parent 7a07c4a commit e68da18

File tree

4 files changed

+751
-0
lines changed

4 files changed

+751
-0
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ if(WITH_MKLDNN)
140140
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
141141
pass_library(multi_gru_fuse_pass inference DIR mkldnn)
142142
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
143+
pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn)
143144
endif()
144145

145146
if(WITH_IPU)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/graph_helper.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
// Stores every entry of `info_map` as an attribute on the first operator node
// of `graph` that is neither a "feed" nor a "fetch" op.
//
// Nodes are visited in the order returned by
// ir::TopologyVarientSort(*graph, SortKind(0)). On the selected op:
//   * the boolean attribute `flag` is set to true, and
//   * each (name, values) pair is written as attribute
//     name + "_" + key_suffix + "_" + flag.
// Only that single op is modified; if no such op exists nothing is written.
//
// graph      - graph whose first op receives the attributes (dereferenced
//              unconditionally, so it must not be null).
// flag       - name of the marker attribute; also the tail of every key.
// key_suffix - middle part of the generated attribute names.
// info_map   - name -> float vector entries to store.
static void SaveInfoInTheFirstOp(
    ir::Graph* graph, const std::string& flag, const std::string& key_suffix,
    const std::unordered_map<std::string, std::vector<float>>& info_map) {
  VLOG(3) << "save variables in the first op's attr";

  const std::string attr_tail = "_" + key_suffix + "_" + flag;
  const auto sorted_nodes =
      ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0));

  for (auto* node : sorted_nodes) {
    if (!node->IsOp()) {
      continue;
    }

    auto* desc = node->Op();
    const bool is_feed_or_fetch =
        desc->Type() == "feed" || desc->Type() == "fetch";
    if (is_feed_or_fetch) {
      continue;
    }

    // First real operator found: mark it, attach all entries, and stop.
    desc->SetAttr(flag, true);
    for (const auto& entry : info_map) {
      desc->SetAttr(entry.first + attr_tail, entry.second);
    }
    return;
  }
}
43+
44+
// Reads back the attributes written by SaveInfoInTheFirstOp and removes them
// from the op.
//
// Nodes are visited in the order returned by
// ir::TopologyVarientSort(*graph, SortKind(0)); "feed"/"fetch" ops and non-op
// nodes are skipped. The first remaining op whose boolean attribute `flag` is
// true is processed:
//   * `flag` is removed from the op;
//   * every attribute whose name ENDS WITH "_" + key_suffix + "_" + flag is
//     read as std::vector<float>, inserted into `info_map` under the name
//     with that suffix stripped, and removed from the op.
// Ops without the flag are skipped and the search continues. If no op carries
// the flag, `info_map` is left untouched.
//
// graph      - graph to search (dereferenced unconditionally; must not be
//              null).
// flag       - name of the marker attribute; also the tail of every key.
// key_suffix - middle part of the attribute names to look for.
// info_map   - output map (must not be null). Existing keys are not
//              overwritten.
static void GetInfoFromTheFirstOp(
    ir::Graph* graph, const std::string& flag, const std::string& key_suffix,
    std::unordered_map<std::string, std::vector<float>>* info_map) {
  VLOG(3) << "get variables from the first op's attr";

  const std::string suffix = "_" + key_suffix + "_" + flag;
  for (auto* op_node :
       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
        op_node->Op()->Type() == "fetch")
      continue;

    auto* op_desc = op_node->Op();
    if (!op_desc->GetAttrIfExists<bool>(flag)) continue;

    op_desc->RemoveAttr(flag);
    // Iterate over a snapshot of the names: attributes are removed from
    // op_desc inside the loop.
    const std::vector<std::string> attr_names = op_desc->AttrNames();
    for (const auto& fake_name : attr_names) {
      // The saver always appends the suffix, so only a match at the very end
      // of the attribute name identifies a stored entry. Searching for the
      // first occurrence would truncate names that contain the suffix text
      // themselves and would pick up unrelated attributes.
      if (fake_name.size() < suffix.size()) continue;
      const size_t pos = fake_name.size() - suffix.size();
      if (fake_name.compare(pos, suffix.size(), suffix) != 0) continue;

      // Copy the values out before the attribute is removed below.
      auto scales_vector =
          BOOST_GET_CONST(std::vector<float>, op_desc->GetAttr(fake_name));
      info_map->insert(
          std::make_pair(fake_name.substr(0, pos), scales_vector));
      op_desc->RemoveAttr(fake_name);
    }
    break;
  }
}
74+
75+
} // namespace ir
76+
} // namespace framework
77+
} // namespace paddle

0 commit comments

Comments
 (0)