Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lite/backends/arm/math/fp16/conv_impl_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ void conv1x1s1_gemm_fp16(CONV_PARAM(float16_t)) {
}
}

- gemv_fp16(weights_group,
- din_group,
+ gemv_fp16(din_group,
+ weights_group,
dout_group,
true,
n,
Expand Down Expand Up @@ -453,8 +453,8 @@ void conv_im2col_gemm_fp16(CONV_PARAM(float16_t)) {
}
}

- gemv_fp16(weights_group,
- dB,
+ gemv_fp16(dB,
+ weights_group,
dout_group,
true,
n,
Expand Down
10 changes: 3 additions & 7 deletions lite/backends/arm/math/fp16/gemv_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ void gemv_fp16_trans(const float16_t *A,
ARMContext *ctx) {
int Nup = (N + 7) / 8 * 8;
int Mup = (M + 7) / 8 * 8;
- auto size = (Mup * 3 + Nup);
+ auto size = (Mup * 2 + Nup);
ctx->ExtendWorkspace(size * sizeof(float16_t));
auto ptr_zero = ctx->workspace_data<float16_t>();
memset(ptr_zero, 0, Mup * sizeof(float16_t));
Expand All @@ -289,12 +289,8 @@ void gemv_fp16_trans(const float16_t *A,
memset(bias_ptr, 0, Mup * sizeof(float16_t));
}
float16_t *ptr_w = bias_ptr + Mup;
- lite::TargetWrapperHost::MemcpySync(ptr_w, A, N * sizeof(float16_t));
+ lite::TargetWrapperHost::MemcpySync(ptr_w, x, N * sizeof(float16_t));
memset(ptr_w + N, 0, (Nup - N) * sizeof(float16_t));
- float16_t *data_in = ptr_w + Nup;
- lite::TargetWrapperHost::MemcpySync(
- data_in, x + (N - 1) * M, M * sizeof(float16_t));
- memset(data_in + M, 0, (Mup - M) * sizeof(float16_t));
memset(y, 0, M * sizeof(float16_t));
float16_t local_alpha = 0.f;
float16_t offset = 0.f;
Expand All @@ -317,7 +313,7 @@ void gemv_fp16_trans(const float16_t *A,
int y_index = j * 8;
const float16_t *ptr_in = ptr_w + y_index;
const float16_t *inptr_row[8];
- inptr_row[0] = x + y_index * M;
+ inptr_row[0] = A + y_index * M;
for (int i = 1; i < 8; i++) {
inptr_row[i] = inptr_row[i - 1] + M;
}
Expand Down
4 changes: 2 additions & 2 deletions lite/tests/math/gemv_fp16_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ bool test_sgemv_fp16(bool tra,
dc_basic,
static_cast<float16_t>(1.f),
static_cast<float16_t>(0.f),
- false,
+ tra,
has_bias,
flag_act,
alpha);
Expand Down Expand Up @@ -211,7 +211,7 @@ TEST(TestLiteGemvFP16, gemv_fp16) {
LOG(INFO) << "run basic sgemm test";
for (auto& m : {3, 8, 32, 397}) {
for (auto& n : {3, 13, 141, 512, 789}) {
- for (auto& tra : {false}) {
+ for (auto& tra : {false, true}) {
for (auto& has_bias : {false, true}) {
for (auto& flag_act : {0, 1}) {
for (auto& th : {1, 2, 4}) {
Expand Down