File tree Expand file tree Collapse file tree 1 file changed +13
-5
lines changed Expand file tree Collapse file tree 1 file changed +13
-5
lines changed Original file line number Diff line number Diff line change 13
13
// limitations under the License.
14
14
15
15
#include " lite/core/mir/fusion/conv_conv_fuse_pass.h"
16
+ #include < list>
16
17
#include < memory>
17
18
#include < vector>
18
19
#include " lite/core/mir/fusion/conv_conv_fuser.h"
@@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
27
28
// initialze fuser params
28
29
std::vector<bool > conv_has_bias_cases{true , false };
29
30
std::vector<std::string> conv_type_cases{" conv2d" , " depthwise_conv2d" };
30
- bool has_fp32 = false ;
31
31
bool has_int8 = false ;
32
+ bool has_weight_quant = false ;
32
33
for (auto & place : graph->valid_places ()) {
33
34
if (place.target == TARGET (kARM ) || place.target == TARGET (kHost )) {
34
- if (place.precision == PRECISION (kFloat )) {
35
- has_fp32 = true ;
36
- }
37
35
if (place.precision == PRECISION (kInt8 )) {
38
36
has_int8 = true ;
39
37
}
@@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
42
40
return ;
43
41
}
44
42
}
43
+ const std::list<mir::Node>& nodes = graph->nodes ();
44
+ for (auto & node : nodes) {
45
+ if (node.IsStmt ()) {
46
+ auto * op_info = (node.stmt ())->op_info ();
47
+ if (op_info->HasAttr (" quantization_type" )) {
48
+ has_weight_quant = true ;
49
+ break ;
50
+ }
51
+ }
52
+ }
45
53
// only support arm-fp32
46
- if (has_int8 || (has_fp32 && has_int8) ) {
54
+ if (has_int8 || has_weight_quant ) {
47
55
return ;
48
56
}
49
57
// only support fp32 fusion
You can’t perform that action at this time.
0 commit comments