Skip to content

Commit 292f52c

Browse files
authored
change box_coder. test=develop (#4107)
1 parent 17f0063 commit 292f52c

File tree

3 files changed

+228
-247
lines changed

3 files changed

+228
-247
lines changed

lite/backends/arm/math/box_coder.cc

Lines changed: 206 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,69 +20,214 @@ namespace lite {
2020
namespace arm {
2121
namespace math {
2222

23-
void box_coder(lite::Tensor* proposals,
24-
const lite::Tensor* anchors,
25-
const lite::Tensor* variances,
26-
const lite::Tensor* bbox_deltas,
27-
const std::string code_type,
28-
bool box_normalized,
29-
int axis) {
30-
if (code_type == "decode_center_size") {
31-
float normalized = !box_normalized ? 1.f : 0;
32-
33-
const float* anchor_data = anchors->data<float>();
34-
const float* bbox_deltas_data = bbox_deltas->data<float>();
35-
float* proposals_data = proposals->mutable_data<float>();
36-
const float* variances_data = variances->data<float>();
37-
38-
int N = bbox_deltas->dims()[0];
39-
int M = bbox_deltas->dims()[1];
40-
int len = bbox_deltas->dims()[2];
41-
42-
for (int64_t row_id = 0; row_id < N; ++row_id) {
43-
for (int64_t col_id = 0; col_id < M; ++col_id) {
44-
size_t offset = row_id * M * len + col_id * len;
45-
int prior_box_offset = axis == 0 ? col_id * len : row_id * len;
46-
int var_offset = axis == 0 ? col_id * len : row_id * len;
47-
48-
auto anchor_data_tmp = anchor_data + prior_box_offset;
49-
auto bbox_deltas_data_tmp = bbox_deltas_data + offset;
50-
auto proposals_data_tmp = proposals_data + offset;
51-
52-
auto anchor_width =
53-
anchor_data_tmp[2] - anchor_data_tmp[0] + normalized;
54-
auto anchor_height =
55-
anchor_data_tmp[3] - anchor_data_tmp[1] + normalized;
56-
auto anchor_center_x = anchor_data_tmp[0] + 0.5 * anchor_width;
57-
auto anchor_center_y = anchor_data_tmp[1] + 0.5 * anchor_height;
58-
59-
float bbox_center_x = 0, bbox_center_y = 0;
60-
float bbox_width = 0, bbox_height = 0;
61-
62-
auto variances_data_tmp = variances_data + var_offset;
63-
64-
bbox_center_x =
65-
variances_data_tmp[0] * bbox_deltas_data_tmp[0] * anchor_width +
66-
anchor_center_x;
67-
bbox_center_y =
68-
variances_data_tmp[1] * bbox_deltas_data_tmp[1] * anchor_height +
69-
anchor_center_y;
70-
bbox_width = std::exp(variances_data_tmp[2] * bbox_deltas_data_tmp[2]) *
71-
anchor_width;
72-
bbox_height =
73-
std::exp(variances_data_tmp[3] * bbox_deltas_data_tmp[3]) *
74-
anchor_height;
75-
76-
proposals_data_tmp[0] = bbox_center_x - bbox_width / 2;
77-
proposals_data_tmp[1] = bbox_center_y - bbox_height / 2;
78-
proposals_data_tmp[2] = bbox_center_x + bbox_width / 2 - normalized;
79-
proposals_data_tmp[3] = bbox_center_y + bbox_height / 2 - normalized;
80-
}
23+
void decode_bbox_center_variance_kernel(const int batch_num,
24+
const float* loc_data,
25+
const float* prior_data,
26+
const float* variance,
27+
const int num_priors,
28+
float* bbox_data) {
29+
int cnt = num_priors / 4;
30+
//! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax
31+
//! vloc 0: xmin, 1: ymin, 2: xmax, 3: ymax
32+
//! vvar
33+
float32x4_t vhalf = vdupq_n_f32(0.5f);
34+
35+
int len_batch = num_priors * 4;
36+
37+
for (int n = 0; n < batch_num; ++n) {
38+
const float* ptr_loc_batch = loc_data + n * len_batch;
39+
float* ptr_bbox_batch = bbox_data + n * len_batch;
40+
41+
#pragma omp parallel for
42+
for (int i = 0; i < cnt; ++i) {
43+
int idx = i * 16;
44+
const float* ptr_loc = ptr_loc_batch + idx;
45+
const float* ptr_prior = prior_data + idx;
46+
float* ptr_bbox = ptr_bbox_batch + idx;
47+
48+
float32x4x4_t vprior = vld4q_f32(ptr_prior);
49+
float32x4x4_t vloc = vld4q_f32(ptr_loc);
50+
float32x4_t vprior_width = vsubq_f32(vprior.val[2], vprior.val[0]);
51+
float32x4_t vprior_height = vsubq_f32(vprior.val[3], vprior.val[1]);
52+
float32x4_t vprior_cx =
53+
vmulq_f32(vaddq_f32(vprior.val[0], vprior.val[2]), vhalf);
54+
float32x4_t vprior_cy =
55+
vmulq_f32(vaddq_f32(vprior.val[1], vprior.val[3]), vhalf);
56+
57+
float32x4_t vdec_bbx_cx =
58+
vaddq_f32(vmulq_f32(vloc.val[0], vprior_width), vprior_cx);
59+
float32x4_t vdec_bbx_cy =
60+
vaddq_f32(vmulq_f32(vloc.val[1], vprior_height), vprior_cy);
61+
float32x4_t vdec_bbx_w = exp_ps(vloc.val[2]);
62+
float32x4_t vdec_bbx_h = exp_ps(vloc.val[3]);
63+
vprior_width = vmulq_f32(vprior_width, vhalf);
64+
vprior_height = vmulq_f32(vprior_height, vhalf);
65+
vdec_bbx_w = vmulq_f32(vdec_bbx_w, vprior_width);
66+
vdec_bbx_h = vmulq_f32(vdec_bbx_h, vprior_height);
67+
68+
vloc.val[0] = vsubq_f32(vdec_bbx_cx, vdec_bbx_w);
69+
vloc.val[1] = vsubq_f32(vdec_bbx_cy, vdec_bbx_h);
70+
vloc.val[2] = vaddq_f32(vdec_bbx_cx, vdec_bbx_w);
71+
vloc.val[3] = vaddq_f32(vdec_bbx_cy, vdec_bbx_h);
72+
73+
vst4q_f32(ptr_bbox, vloc);
8174
}
82-
} else if (code_type == "encode_center_size") {
83-
LOG(FATAL) << "not implemented type: " << code_type;
75+
#pragma omp parallel for
76+
for (int i = cnt * 4; i < num_priors; i++) {
77+
int idx = i * 4;
78+
float p_xmin = prior_data[idx];
79+
float p_ymin = prior_data[idx + 1];
80+
float p_xmax = prior_data[idx + 2];
81+
float p_ymax = prior_data[idx + 3];
82+
float prior_width = p_xmax - p_xmin;
83+
float prior_height = p_ymax - p_ymin;
84+
float prior_center_x = (p_xmin + p_xmax) / 2.f;
85+
float prior_center_y = (p_ymin + p_ymax) / 2.f;
86+
87+
float xmin = ptr_loc_batch[idx];
88+
float ymin = ptr_loc_batch[idx + 1];
89+
float xmax = ptr_loc_batch[idx + 2];
90+
float ymax = ptr_loc_batch[idx + 3];
91+
92+
//! variance is encoded in target, we simply need to retore the offset
93+
//! predictions.
94+
float decode_bbox_center_x = xmin * prior_width + prior_center_x;
95+
float decode_bbox_center_y = ymin * prior_height + prior_center_y;
96+
float decode_bbox_width = expf(xmax) * prior_width;
97+
float decode_bbox_height = expf(ymax) * prior_height;
98+
99+
ptr_bbox_batch[idx] = decode_bbox_center_x - decode_bbox_width / 2.f;
100+
ptr_bbox_batch[idx + 1] = decode_bbox_center_y - decode_bbox_height / 2.f;
101+
ptr_bbox_batch[idx + 2] = decode_bbox_center_x + decode_bbox_width / 2.f;
102+
ptr_bbox_batch[idx + 3] = decode_bbox_center_y + decode_bbox_height / 2.f;
103+
}
104+
}
105+
}
106+
107+
void decode_bbox_center_no_variance_kernel(const int batch_num,
108+
const float* loc_data,
109+
const float* prior_data,
110+
const float* variance,
111+
const int num_priors,
112+
const bool normalized,
113+
float* bbox_data) {
114+
int cnt = num_priors / 4;
115+
//! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax
116+
//! vloc 0: xmin, 1: ymin, 2: xmax, 3: ymax
117+
//! vvar
118+
float32x4_t vhalf = vdupq_n_f32(0.5f);
119+
float norm_value = (normalized == false);
120+
float32x4_t vnormalized = vdupq_n_f32(norm_value);
121+
int len_batch = num_priors * 4;
122+
123+
for (int n = 0; n < batch_num; ++n) {
124+
const float* ptr_loc_batch = loc_data + n * len_batch;
125+
float* ptr_bbox_batch = bbox_data + n * len_batch;
126+
127+
#pragma omp parallel for
128+
for (int i = 0; i < cnt; ++i) {
129+
int idx = i * 16;
130+
131+
const float* ptr_loc = ptr_loc_batch + idx;
132+
const float* ptr_prior = prior_data + idx;
133+
const float* ptr_var = variance + idx;
134+
float* ptr_bbox = ptr_bbox_batch + idx;
135+
136+
float32x4x4_t vprior = vld4q_f32(ptr_prior);
137+
float32x4x4_t vloc = vld4q_f32(ptr_loc);
138+
float32x4x4_t vvar = vld4q_f32(ptr_var);
139+
float32x4_t vprior_width1 = vsubq_f32(vprior.val[2], vprior.val[0]);
140+
float32x4_t vprior_height1 = vsubq_f32(vprior.val[3], vprior.val[1]);
141+
float32x4_t vprior_width = vaddq_f32(vprior_width1, vnormalized);
142+
float32x4_t vprior_height = vaddq_f32(vprior_height1, vnormalized);
143+
144+
float32x4_t vprior_cx =
145+
vaddq_f32(vprior.val[0], vmulq_f32(vprior_width, vhalf));
146+
float32x4_t vprior_cy =
147+
vaddq_f32(vprior.val[1], vmulq_f32(vprior_height, vhalf));
148+
149+
vloc.val[0] = vmulq_f32(vloc.val[0], vvar.val[0]);
150+
vloc.val[1] = vmulq_f32(vloc.val[1], vvar.val[1]);
151+
vloc.val[2] = vmulq_f32(vloc.val[2], vvar.val[2]);
152+
vloc.val[3] = vmulq_f32(vloc.val[3], vvar.val[3]);
153+
154+
float32x4_t vdec_bbx_cx =
155+
vaddq_f32(vmulq_f32(vloc.val[0], vprior_width), vprior_cx);
156+
float32x4_t vdec_bbx_cy =
157+
vaddq_f32(vmulq_f32(vloc.val[1], vprior_height), vprior_cy);
158+
float32x4_t vdec_bbx_w = exp_ps(vloc.val[2]);
159+
float32x4_t vdec_bbx_h = exp_ps(vloc.val[3]);
160+
vprior_width = vmulq_f32(vprior_width, vhalf);
161+
vprior_height = vmulq_f32(vprior_height, vhalf);
162+
vdec_bbx_w = vmulq_f32(vdec_bbx_w, vprior_width);
163+
vdec_bbx_h = vmulq_f32(vdec_bbx_h, vprior_height);
164+
165+
vloc.val[0] = vsubq_f32(vdec_bbx_cx, vdec_bbx_w);
166+
vloc.val[1] = vsubq_f32(vdec_bbx_cy, vdec_bbx_h);
167+
vloc.val[2] = vaddq_f32(vdec_bbx_cx, vsubq_f32(vdec_bbx_w, vnormalized));
168+
vloc.val[3] = vaddq_f32(vdec_bbx_cy, vsubq_f32(vdec_bbx_h, vnormalized));
169+
170+
vst4q_f32(ptr_bbox, vloc);
171+
}
172+
173+
#pragma omp parallel for
174+
for (int i = cnt * 4; i < num_priors; i++) {
175+
int idx = i * 4;
176+
float p_xmin = prior_data[idx];
177+
float p_ymin = prior_data[idx + 1];
178+
float p_xmax = prior_data[idx + 2];
179+
float p_ymax = prior_data[idx + 3];
180+
float prior_width = p_xmax - p_xmin + norm_value;
181+
float prior_height = p_ymax - p_ymin + norm_value;
182+
float prior_center_x = p_xmin + prior_width / 2.f;
183+
float prior_center_y = p_ymin + prior_height / 2.f;
184+
185+
float xmin = ptr_loc_batch[idx];
186+
float ymin = ptr_loc_batch[idx + 1];
187+
float xmax = ptr_loc_batch[idx + 2];
188+
float ymax = ptr_loc_batch[idx + 3];
189+
190+
//! variance is encoded in target, we simply need to retore the offset
191+
//! predictions.
192+
float decode_bbox_center_x =
193+
variance[idx] * xmin * prior_width + prior_center_x;
194+
float decode_bbox_center_y =
195+
variance[idx + 1] * ymin * prior_height + prior_center_y;
196+
float decode_bbox_width = expf(variance[idx + 2] * xmax) * prior_width;
197+
float decode_bbox_height = expf(variance[idx + 3] * ymax) * prior_height;
198+
199+
ptr_bbox_batch[idx] = decode_bbox_center_x - decode_bbox_width / 2.f;
200+
ptr_bbox_batch[idx + 1] = decode_bbox_center_y - decode_bbox_height / 2.f;
201+
ptr_bbox_batch[idx + 2] =
202+
decode_bbox_center_x + decode_bbox_width / 2.f - norm_value;
203+
ptr_bbox_batch[idx + 3] =
204+
decode_bbox_center_y + decode_bbox_height / 2.f - norm_value;
205+
}
206+
}
207+
}
208+
209+
void decode_bboxes(const int batch_num,
210+
const float* loc_data,
211+
const float* prior_data,
212+
const float* variance_data,
213+
const std::string code_type,
214+
const bool normalized,
215+
const int num_priors,
216+
float* bbox_data) {
217+
if (code_type == "encode_center_size") {
218+
decode_bbox_center_variance_kernel(
219+
batch_num, loc_data, prior_data, variance_data, num_priors, bbox_data);
220+
221+
} else if (code_type == "decode_center_size") {
222+
decode_bbox_center_no_variance_kernel(batch_num,
223+
loc_data,
224+
prior_data,
225+
variance_data,
226+
num_priors,
227+
normalized,
228+
bbox_data);
84229
} else {
85-
LOG(FATAL) << "not supported type: " << code_type;
230+
LOG(FATAL) << "box_coder don't support this code_type: " << code_type;
86231
}
87232
}
88233

lite/backends/arm/math/box_coder.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ namespace lite {
2222
namespace arm {
2323
namespace math {
2424

25-
void box_coder(lite::Tensor* proposals,
26-
const lite::Tensor* anchors,
27-
const lite::Tensor* variances,
28-
const lite::Tensor* bbox_deltas,
29-
const std::string code_type,
30-
bool box_normalized,
31-
int axis);
25+
void decode_bboxes(const int batch_num,
26+
const float* loc_data,
27+
const float* prior_data,
28+
const float* variance_data,
29+
const std::string code_type,
30+
const bool normalized,
31+
const int num_priors,
32+
float* bbox_data);
3233

3334
} // namespace math
3435
} // namespace arm

0 commit comments

Comments
 (0)