Skip to content

Commit 20e7c12

Browse files
[arm-fp16] add conv+hardswish support in fp16 v8 (#7403)
1 parent cd9f182 commit 20e7c12

12 files changed

+544
-84
lines changed

lite/backends/arm/math/fp16/common_preprocess.h

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,19 +166,26 @@ typedef __fp16 float16_t;
166166
inline void act_acquire(lite_api::ActivationType act,
167167
int &flag_act, // NOLINT
168168
float &local_alpha, // NOLINT
169-
float six,
170-
float alpha) {
169+
float &offset, // NOLINT
170+
float &threshold, // NOLINT
171+
const operators::ActivationParam act_param) {
171172
switch (act) {
172173
case lite_api::ActivationType::kRelu:
173174
flag_act = 0x01;
174175
break;
175176
case lite_api::ActivationType::kRelu6:
176177
flag_act = 0x02;
177-
local_alpha = six;
178+
local_alpha = act_param.Relu_clipped_coef;
178179
break;
179180
case lite_api::ActivationType::kLeakyRelu:
180181
flag_act = 0x03;
181-
local_alpha = alpha;
182+
local_alpha = act_param.Leaky_relu_alpha;
183+
break;
184+
case lite_api::ActivationType::kHardSwish:
185+
flag_act = 0x04;
186+
local_alpha = 1.0 / act_param.hard_swish_scale;
187+
offset = act_param.hard_swish_offset;
188+
threshold = act_param.hard_swish_threshold;
182189
break;
183190
default:
184191
break;

lite/backends/arm/math/fp16/conv3x3_winograd_fp16.cc

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -114,13 +114,11 @@ void conv_compute_2x2_3x3_fp16(const float16_t* input,
114114
auto act_type = act_param.active_type;
115115
float local_alpha = 0.f;
116116
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3, hardswish: 4
117+
float offset = 0.f;
118+
float threshold = 6.f;
117119

118120
if (act_param.has_active) {
119-
act_acquire(act_type,
120-
flag_act,
121-
local_alpha,
122-
act_param.Relu_clipped_coef,
123-
act_param.Leaky_relu_alpha);
121+
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
124122
}
125123

126124
// begin compute
@@ -265,7 +263,9 @@ void conv_compute_2x2_3x3_fp16(const float16_t* input,
265263
flag_act,
266264
local_alpha,
267265
bias + ci * 8,
268-
flag_bias);
266+
flag_bias,
267+
offset,
268+
threshold);
269269
}
270270
} else {
271271
for (int ci = 0; ci < oc_8; ++ci) {
@@ -299,7 +299,9 @@ void conv_compute_2x2_3x3_fp16(const float16_t* input,
299299
flag_act,
300300
local_alpha,
301301
bias + ci * 8,
302-
flag_bias);
302+
flag_bias,
303+
offset,
304+
threshold);
303305
}
304306
}
305307
}
@@ -373,13 +375,11 @@ void conv_compute_4x4_3x3_fp16(const float16_t* input,
373375
float local_alpha = 0.f;
374376
bool flag_bias = (bias != nullptr);
375377
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3, hardswish: 4
378+
float offset = 0.f;
379+
float threshold = 6.f;
376380

377381
if (act_param.has_active) {
378-
act_acquire(act_type,
379-
flag_act,
380-
local_alpha,
381-
act_param.Relu_clipped_coef,
382-
act_param.Leaky_relu_alpha);
382+
act_acquire(act_type, flag_act, local_alpha, offset, threshold, act_param);
383383
}
384384

385385
// begin compute
@@ -543,7 +543,9 @@ void conv_compute_4x4_3x3_fp16(const float16_t* input,
543543
flag_act,
544544
local_alpha,
545545
bias + ci * 8,
546-
flag_bias);
546+
flag_bias,
547+
offset,
548+
threshold);
547549
}
548550
} else {
549551
for (int ci = 0; ci < oc_8; ++ci) {
@@ -583,7 +585,9 @@ void conv_compute_4x4_3x3_fp16(const float16_t* input,
583585
flag_act,
584586
local_alpha,
585587
bias + ci * 8,
586-
flag_bias);
588+
flag_bias,
589+
offset,
590+
threshold);
587591
}
588592
}
589593
}

lite/backends/arm/math/fp16/conv3x3s1_direct_fp16.cc

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -494,12 +494,11 @@ void conv_3x3s1_direct_fp16(const float16_t* i_data,
494494
float alpha = 0.f;
495495
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3, hardswish: 4
496496

497+
float offset = 0.f;
498+
float threshold = 6.f;
499+
497500
if (act_param.has_active) {
498-
act_acquire(act_type,
499-
flag_act,
500-
alpha,
501-
act_param.Relu_clipped_coef,
502-
act_param.Leaky_relu_alpha);
501+
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
503502
}
504503

505504
for (int n = 0; n < bs; ++n) {
@@ -617,7 +616,9 @@ void conv_3x3s1_direct_fp16(const float16_t* i_data,
617616
flag_act,
618617
alpha,
619618
bias_ptr,
620-
flag_bias);
619+
flag_bias,
620+
offset,
621+
threshold);
621622
}
622623
}
623624
}

lite/backends/arm/math/fp16/conv3x3s2_direct_fp16.cc

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -626,16 +626,13 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
626626
auto act_type = act_param.active_type;
627627
bool flag_bias = param.bias != nullptr;
628628
float alpha = 0.f;
629-
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
629+
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3, hardswish: 4
630+
float offset = 0.f;
631+
float threshold = 6.f;
630632

631633
if (act_param.has_active) {
632-
act_acquire(act_type,
633-
flag_act,
634-
alpha,
635-
act_param.Relu_clipped_coef,
636-
act_param.Leaky_relu_alpha);
634+
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
637635
}
638-
639636
for (int n = 0; n < bs; ++n) {
640637
const float16_t* din_batch = i_data + n * ic * size_in_channel;
641638
float16_t* dout_batch = o_data + n * oc * size_out_channel;
@@ -734,7 +731,9 @@ void conv_3x3s2_direct_fp16(const float16_t* i_data,
734731
flag_act,
735732
alpha,
736733
bias_ptr,
737-
flag_bias);
734+
flag_bias,
735+
offset,
736+
threshold);
738737
}
739738
}
740739
}
@@ -787,14 +786,12 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
787786
auto act_type = act_param.active_type;
788787
bool flag_bias = param.bias != nullptr;
789788
float alpha = 0.f;
790-
int flag_act = 0x00; // relu: 1, relu6: 2, leakey: 3
789+
int flag_act = 0x00;
790+
float offset = 0.f;
791+
float threshold = 6.f;
791792

792793
if (act_param.has_active) {
793-
act_acquire(act_type,
794-
flag_act,
795-
alpha,
796-
act_param.Relu_clipped_coef,
797-
act_param.Leaky_relu_alpha);
794+
act_acquire(act_type, flag_act, alpha, offset, threshold, act_param);
798795
}
799796

800797
for (int n = 0; n < bs; ++n) {
@@ -917,7 +914,9 @@ void conv_3x3s2_direct_fp16_c3(const float16_t* i_data,
917914
flag_act,
918915
alpha,
919916
bias_ptr,
920-
flag_bias);
917+
flag_bias,
918+
offset,
919+
threshold);
921920
}
922921
}
923922
}

0 commit comments

Comments
 (0)