Skip to content

Commit 7a610b3

Browse files
authored
Gemv fp16 fix (#9331)
* fixed fp16 gemv compute error while m = 1
1 parent b5e5795 commit 7a610b3

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

lite/backends/arm/math/fp16/conv_impl_fp16.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ void conv1x1s1_gemm_fp16(CONV_PARAM(float16_t)) {
339339
}
340340
}
341341

342-
gemv_fp16(weights_group,
343-
din_group,
342+
gemv_fp16(din_group,
343+
weights_group,
344344
dout_group,
345345
true,
346346
n,
@@ -454,8 +454,8 @@ void conv_im2col_gemm_fp16(CONV_PARAM(float16_t)) {
454454
}
455455
}
456456

457-
gemv_fp16(weights_group,
458-
dB,
457+
gemv_fp16(dB,
458+
weights_group,
459459
dout_group,
460460
true,
461461
n,

lite/backends/arm/math/fp16/gemv_fp16.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ void gemv_fp16_trans(const float16_t *A,
277277
ARMContext *ctx) {
278278
int Nup = (N + 7) / 8 * 8;
279279
int Mup = (M + 7) / 8 * 8;
280-
auto size = (Mup * 3 + Nup);
280+
auto size = (Mup * 2 + Nup);
281281
ctx->ExtendWorkspace(size * sizeof(float16_t));
282282
auto ptr_zero = ctx->workspace_data<float16_t>();
283283
memset(ptr_zero, 0, Mup * sizeof(float16_t));
@@ -289,12 +289,8 @@ void gemv_fp16_trans(const float16_t *A,
289289
memset(bias_ptr, 0, Mup * sizeof(float16_t));
290290
}
291291
float16_t *ptr_w = bias_ptr + Mup;
292-
lite::TargetWrapperHost::MemcpySync(ptr_w, A, N * sizeof(float16_t));
292+
lite::TargetWrapperHost::MemcpySync(ptr_w, x, N * sizeof(float16_t));
293293
memset(ptr_w + N, 0, (Nup - N) * sizeof(float16_t));
294-
float16_t *data_in = ptr_w + Nup;
295-
lite::TargetWrapperHost::MemcpySync(
296-
data_in, x + (N - 1) * M, M * sizeof(float16_t));
297-
memset(data_in + M, 0, (Mup - M) * sizeof(float16_t));
298294
memset(y, 0, M * sizeof(float16_t));
299295
float16_t local_alpha = 0.f;
300296
float16_t offset = 0.f;
@@ -317,7 +313,7 @@ void gemv_fp16_trans(const float16_t *A,
317313
int y_index = j * 8;
318314
const float16_t *ptr_in = ptr_w + y_index;
319315
const float16_t *inptr_row[8];
320-
inptr_row[0] = x + y_index * M;
316+
inptr_row[0] = A + y_index * M;
321317
for (int i = 1; i < 8; i++) {
322318
inptr_row[i] = inptr_row[i - 1] + M;
323319
}

lite/tests/math/gemv_fp16_compute_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ bool test_sgemv_fp16(bool tra,
136136
dc_basic,
137137
static_cast<float16_t>(1.f),
138138
static_cast<float16_t>(0.f),
139-
false,
139+
tra,
140140
has_bias,
141141
flag_act,
142142
alpha);
@@ -211,7 +211,7 @@ TEST(TestLiteGemvFP16, gemv_fp16) {
211211
LOG(INFO) << "run basic sgemm test";
212212
for (auto& m : {3, 8, 32, 397}) {
213213
for (auto& n : {3, 13, 141, 512, 789}) {
214-
for (auto& tra : {false}) {
214+
for (auto& tra : {false, true}) {
215215
for (auto& has_bias : {false, true}) {
216216
for (auto& flag_act : {0, 1}) {
217217
for (auto& th : {1, 2, 4}) {

0 commit comments

Comments
 (0)