Skip to content

Commit 08fb7c8

Browse files
committed
Weight quantization skip conv_conv_fuse_pass, test=develop
1 parent a17d7be commit 08fb7c8

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

lite/core/mir/fusion/conv_conv_fuse_pass.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
16+
#include <list>
1617
#include <memory>
1718
#include <vector>
1819
#include "lite/core/mir/fusion/conv_conv_fuser.h"
@@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
2728
// initialze fuser params
2829
std::vector<bool> conv_has_bias_cases{true, false};
2930
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
30-
bool has_fp32 = false;
3131
bool has_int8 = false;
32+
bool has_weight_quant = false;
3233
for (auto& place : graph->valid_places()) {
3334
if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) {
34-
if (place.precision == PRECISION(kFloat)) {
35-
has_fp32 = true;
36-
}
3735
if (place.precision == PRECISION(kInt8)) {
3836
has_int8 = true;
3937
}
@@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
4240
return;
4341
}
4442
}
43+
const std::list<mir::Node>& nodes = graph->nodes();
44+
for (auto& node : nodes) {
45+
if (node.IsStmt()) {
46+
auto* op_info = (node.stmt())->op_info();
47+
if (op_info->HasAttr("quantization_type")) {
48+
has_weight_quant = true;
49+
break;
50+
}
51+
}
52+
}
4553
// only support arm-fp32
46-
if (has_int8 || (has_fp32 && has_int8)) {
54+
if (has_int8 || has_weight_quant) {
4755
return;
4856
}
4957
// only support fp32 fusion

0 commit comments

Comments
 (0)