Merged
Changes from all commits
248 changes: 248 additions & 0 deletions lite/backends/arm/math/gru.h
@@ -0,0 +1,248 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
Collaborator comment: When you have time, please open a new PR to update the date: 2019 -> 2021.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "lite/backends/arm/math/sgemm.h"
#ifdef LITE_WITH_ARM
#include <arm_neon.h>
#endif

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
struct RNNGRUValue {
  const T* gate_weight;     // not read by the kernels in this file
  const T* state_weight;    // recurrent weights, frame_size x frame_size
  const T* reset_bias;      // bias added to reset_output_value, length frame_size
  T* gate_value;            // gate pre-activations, batch_size x (3 * frame_size), laid out [reset, update, candidate]
  T* reset_output_value;    // prev_out * state_weight^T, batch_size x frame_size
  T* output_value;          // hidden-state output, batch_size x frame_size
  const T* prev_out_value;  // previous hidden state, batch_size x frame_size; may be nullptr
};

template <typename T>
void rnn_activation(const T* din,
T* dout,
int size,
lite_api::ActivationType act_type,
int threads) {
switch (act_type) {
case lite_api::ActivationType::kSigmoid:
act_sigmoid(din, dout, size, threads);
break;
case lite_api::ActivationType::kSigmoid_v2:
act_sigmoid(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh:
act_tanh(din, dout, size, threads);
break;
case lite_api::ActivationType::kTanh_v2:
act_tanh(din, dout, size, threads);
break;
case lite_api::ActivationType::kRelu:
act_relu(din, dout, size, threads);
break;
default:
LOG(FATAL) << "unsupport activation type:" << static_cast<int>(act_type);
break;
}
}

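// The kernels below consume gate pre-activations laid out per batch row as
// [reset, update, candidate], each frame_size wide, and implement the GRU
// update step:
//   r = sigmoid(reset_pre)
//   u = sigmoid(update_pre)
//   reset_output = (reset_output + reset_bias) .* r
//   c = tanh(candidate_pre + reset_output)
//   h = (1 - u) .* c + u .* h_prev   (h = (1 - u) .* c when prev_out_value is null)
// Note that both kernels hard-code Sigmoid_v2 / Tanh_v2 internally; the
// active_node and active_gate arguments are currently not consulted.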
template <typename T>
void compute_kernel(RNNGRUValue<T> value,
int frame_size,
int batch_size,
lite_api::ActivationType active_node,
lite_api::ActivationType active_gate) {
auto value_reset_gate = value.gate_value;
auto value_update_gate = value.gate_value + frame_size;
auto value_reset_output = value.reset_output_value;
auto value_reset_bias = value.reset_bias;
auto cell_state_value = value.gate_value + 2 * frame_size;
auto value_output = value.output_value;
auto value_prev_out = value.prev_out_value;

for (int b = 0; b < batch_size; b++) {
rnn_activation(value_reset_gate,
value_reset_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);
rnn_activation(value_update_gate,
value_update_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);

for (int i = 0; i < frame_size; i++) {
value_reset_output[i] =
(value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
cell_state_value[i] += value_reset_output[i];
}

rnn_activation(cell_state_value,
cell_state_value,
frame_size,
lite_api::ActivationType::kTanh_v2,
1);

if (value.prev_out_value) {
for (int i = 0; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
value_update_gate[i] * value_prev_out[i];
}
} else {
for (int i = 0; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
}
}

value_reset_gate += frame_size * 3;
value_update_gate += frame_size * 3;
value_reset_output += frame_size;
cell_state_value += frame_size * 3;
value_output += frame_size;
if (value.prev_out_value) {
value_prev_out += frame_size;
}
}
}

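// NEON specialization for float: identical math to the generic kernel above,
// processed four lanes at a time with scalar loops for the remaining tail
// elements.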
template <>
inline void compute_kernel<float>(RNNGRUValue<float> value,
                                  int frame_size,
                                  int batch_size,
                                  lite_api::ActivationType active_node,
                                  lite_api::ActivationType active_gate) {
auto value_reset_gate = value.gate_value;
auto value_update_gate = value.gate_value + frame_size;
auto value_reset_output = value.reset_output_value;
auto value_reset_bias = value.reset_bias;
auto cell_state_value = value.gate_value + 2 * frame_size;
auto value_output = value.output_value;
auto value_prev_out = value.prev_out_value;
int i = 0;
float32x4_t vec_one = vdupq_n_f32(1.f);

for (int b = 0; b < batch_size; b++) {
rnn_activation(value_reset_gate,
value_reset_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);
rnn_activation(value_update_gate,
value_update_gate,
frame_size,
lite_api::ActivationType::kSigmoid_v2,
1);

for (i = 0; i + 3 < frame_size; i += 4) {
float32x4_t vec_out = vld1q_f32(value_reset_output + i);
float32x4_t vec_reset = vld1q_f32(value_reset_gate + i);
float32x4_t vec_bias = vld1q_f32(value_reset_bias + i);
vec_out = vmulq_f32(vaddq_f32(vec_out, vec_bias), vec_reset);
vst1q_f32(value_reset_output + i, vec_out);
vst1q_f32(cell_state_value + i,
vaddq_f32(vec_out, vld1q_f32(cell_state_value + i)));
}
for (; i < frame_size; i++) {
value_reset_output[i] =
(value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
cell_state_value[i] += value_reset_output[i];
}

rnn_activation(cell_state_value,
cell_state_value,
frame_size,
lite_api::ActivationType::kTanh_v2,
1);

if (value.prev_out_value) {
for (i = 0; i + 3 < frame_size; i += 4) {
float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
float32x4_t vec_vpo = vld1q_f32(value_prev_out + i);
float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
vec_vpo = vmulq_f32(vec_vug, vec_vpo);
float32x4_t vec_out =
vmlaq_f32(vec_vpo, vsubq_f32(vec_one, vec_vug), vec_csv);
vst1q_f32(value_output + i, vec_out);
}
for (; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
value_update_gate[i] * value_prev_out[i];
}
} else {
for (i = 0; i + 3 < frame_size; i += 4) {
float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
float32x4_t vec_out = vmulq_f32(vsubq_f32(vec_one, vec_vug), vec_csv);
vst1q_f32(value_output + i, vec_out);
}
for (; i < frame_size; i++) {
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
}
}

value_reset_gate += frame_size * 3;
value_update_gate += frame_size * 3;
value_reset_output += frame_size;
cell_state_value += frame_size * 3;
value_output += frame_size;
if (value.prev_out_value) {
value_prev_out += frame_size;
}
}
}

template <typename T>
struct RnnGruUnitFunctorV2 {
static void compute(ARMContext* ctx,
RNNGRUValue<T> value,
int frame_size,
int batch_size,
lite_api::ActivationType active_node,
lite_api::ActivationType active_gate) {
if (value.prev_out_value) {
operators::ActivationParam act_param;
act_param.has_active = false;
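      // sgemm(false, true, ...) computes, in effect,
      // prev_out_value (batch_size x frame_size) * state_weight^T
      // (state_weight being frame_size x frame_size) and writes the product
      // into reset_output_value, which compute_kernel then combines with
      // reset_bias and the reset gate.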
lite::arm::math::sgemm(false,
true,
batch_size,
frame_size,
frame_size,
1.f,
value.prev_out_value,
frame_size,
value.state_weight,
frame_size,
0.f,
value.reset_output_value,
frame_size,
nullptr,
false,
act_param,
ctx);
}
compute_kernel(value, frame_size, batch_size, active_node, active_gate);
}
};

} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
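For orientation, the sketch below shows how a caller might drive RnnGruUnitFunctorV2 for a single GRU step. It is not part of this PR: the function name, the way the buffers are sized, and the fully qualified type names (paddle::lite::ARMContext, paddle::lite_api::ActivationType) are assumptions inferred from the strides in compute_kernel and the sgemm call above.

// Hypothetical driver for one GRU step (illustration only, not PaddleLite code).
// Assumed buffer shapes:
//   gate_value:         batch_size x (3 * frame_size), pre-activations [r, u, c]
//   reset_output_value: batch_size x frame_size
//   output_value:       batch_size x frame_size
//   prev_out_value:     batch_size x frame_size, or nullptr on the first step
//   state_weight:       frame_size x frame_size
//   reset_bias:         frame_size
#include "lite/backends/arm/math/gru.h"

void run_one_gru_step(paddle::lite::ARMContext* ctx,
                      const float* state_weight,
                      const float* reset_bias,
                      float* gate_value,
                      float* reset_output_value,
                      float* output_value,
                      const float* prev_out_value,
                      int frame_size,
                      int batch_size) {
  paddle::lite::arm::math::RNNGRUValue<float> value;
  value.gate_weight = nullptr;  // not read by this path
  value.state_weight = state_weight;
  value.reset_bias = reset_bias;
  value.gate_value = gate_value;
  value.reset_output_value = reset_output_value;
  value.output_value = output_value;
  value.prev_out_value = prev_out_value;
  // The activation arguments are currently ignored by compute_kernel.
  paddle::lite::arm::math::RnnGruUnitFunctorV2<float>::compute(
      ctx,
      value,
      frame_size,
      batch_size,
      paddle::lite_api::ActivationType::kTanh_v2,
      paddle::lite_api::ActivationType::kSigmoid_v2);
}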
1 change: 1 addition & 0 deletions lite/backends/arm/math/lstm.cc
@@ -36,6 +36,7 @@ void add_bias_rowwise(Tensor* input,
i_data += width;
}
}

void vector_dot(
float* out, const float* in, const float* v1, int size, const float* v2) {
int loop = size >> 2;
39 changes: 21 additions & 18 deletions lite/backends/x86/math/elementwise.h
@@ -150,24 +150,27 @@ namespace math {
} \
}

// marco func add
ElementWiseFunc(Add) ElementWiseFuncBCast(Add)
// marco func sub
ElementWiseFunc(Sub) ElementWiseFuncBCast(Sub)
// marco func mul
ElementWiseFunc(Mul) ElementWiseFuncBCast(Mul)
// marco func max
ElementWiseFunc(Max) ElementWiseFuncBCast(Max)
// marco func min
ElementWiseFunc(Min) ElementWiseFuncBCast(Min)
// marco func div
ElementWiseFunc(Div) ElementWiseFuncBCast(Div)
// marco func floordiv
ElementWiseFunc(FloorDiv) ElementWiseFuncBCast(FloorDiv)
// marco func mod
ElementWiseFunc(Mod) ElementWiseFuncBCast(Mod)
// marco func pow
ElementWiseFunc(Pow) ElementWiseFuncBCast(Pow)
// clang-format off
ElementWiseFunc(Add)
ElementWiseFuncBCast(Add)
ElementWiseFunc(Sub)
ElementWiseFuncBCast(Sub)
ElementWiseFunc(Mul)
ElementWiseFuncBCast(Mul)
ElementWiseFunc(Max)
ElementWiseFuncBCast(Max)
ElementWiseFunc(Min)
ElementWiseFuncBCast(Min)
ElementWiseFunc(Div)
ElementWiseFuncBCast(Div)
ElementWiseFunc(FloorDiv)
ElementWiseFuncBCast(FloorDiv)
ElementWiseFunc(Mod)
ElementWiseFuncBCast(Mod)
ElementWiseFunc(Pow)
ElementWiseFuncBCast(Pow)
// clang-format on

} // namespace math
} // namespace x86
} // namespace lite
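The hunk above is essentially a formatting change: each macro invocation moves onto its own line, the old per-operator comments ("// marco func add", a typo for "macro") are dropped, and the block is wrapped in clang-format off/on so the formatter does not try to reflow the statement-like calls. The expansion of ElementWiseFunc / ElementWiseFuncBCast is not shown in this hunk; the snippet below is a generic, invented illustration of that guard pattern, not the PaddleLite macro.

// Invented example: protect macro-generated definitions from clang-format.
#define DEFINE_ELEMENTWISE(op, expr)                          \
  inline void Elementwise##op(const float* x, const float* y, \
                              float* out, int n) {            \
    for (int i = 0; i < n; ++i) out[i] = (expr);              \
  }

// clang-format off
DEFINE_ELEMENTWISE(Add, x[i] + y[i])
DEFINE_ELEMENTWISE(Sub, x[i] - y[i])
// clang-format on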
7 changes: 0 additions & 7 deletions lite/backends/x86/math/fill_bias_activate.cc
@@ -38,7 +38,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) {
__m256 vec_data = _mm256_loadu_ps(data + i);
_mm256_storeu_ps(data + i, _mm256_max_ps(vec_data, vec_zero));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -59,7 +58,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) {
_mm256_storeu_ps(
data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -112,7 +110,6 @@ static void activate_relu_inplace_bias(float *data,
vec_data = _mm256_add_ps(vec_bias, vec_data);
_mm256_storeu_ps(tmp_data + i, _mm256_max_ps(vec_data, vec_zero));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -140,7 +137,6 @@
tmp_data + i,
_mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -174,7 +170,6 @@ static void activate_lrelu_inplace(float *data, int len, float alpha) {
__m256 vec_mask = _mm256_cmp_ps(vec_data, vec_zero, cmp_le_os);
_mm256_storeu_ps(data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE4_1__ // blendv need 4.1
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -226,7 +221,6 @@ static void activate_lrelu_inplace_bias(float *data,
_mm256_storeu_ps(tmp_data + i,
_mm256_blendv_ps(vec_data, vec_lr, vec_mask));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE4_1__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -273,7 +267,6 @@ static void activate_none_inplace_bias(float *data,
vec_data = _mm256_add_ps(vec_bias, vec_data);
_mm256_storeu_ps(tmp_data + i, vec_data);
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
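The hunks above only remove the seven commented-out _mm256_zeroupper() calls (0 additions, 7 deletions). The surrounding functions share a common structure: an AVX loop over eight floats (guarded by what is presumably __AVX__, opened outside the shown context), an SSE fallback under __SSE__ or __SSE4_1__, and typically a scalar tail loop that is not visible in these truncated hunks. Below is a minimal, self-contained sketch of that pattern for a plain in-place ReLU; it is a generic illustration, not the PaddleLite function.

#include <immintrin.h>

// Generic sketch of the AVX + SSE + scalar-tail in-place ReLU pattern
// used by the functions in this file (illustrative only).
static void relu_inplace_sketch(float *data, int len) {
  int i = 0;
#ifdef __AVX__
  __m256 vec_zero = _mm256_set1_ps(0.f);
  for (; i + 7 < len; i += 8) {
    __m256 v = _mm256_loadu_ps(data + i);
    _mm256_storeu_ps(data + i, _mm256_max_ps(v, vec_zero));
  }
#endif
#ifdef __SSE__
  __m128 vec_zero_128 = _mm_set1_ps(0.f);
  for (; i + 3 < len; i += 4) {
    __m128 v = _mm_loadu_ps(data + i);
    _mm_storeu_ps(data + i, _mm_max_ps(v, vec_zero_128));
  }
#endif
  for (; i < len; i++) {
    data[i] = data[i] > 0.f ? data[i] : 0.f;
  }
}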