// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <arm_neon.h>
#include <cstring>  // for memset
#include <string>
#include "lite/backends/arm/math/activation.h"
#include "lite/core/tensor.h"
#include "lite/utils/logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

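// Adds `bias` row-wise to `input`, touching only columns [start_w, end_w)
// of each row (defined in the corresponding .cc file).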
void add_bias_rowwise(Tensor* input,
                      const Tensor* bias,
                      int start_w,
                      int end_w);

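// Returns a pointer to row `start` of `input`, viewed as a 2-D matrix of
// shape [dims[0], numel() / dims[0]]; an out-of-range `start` maps to one
// past the last element.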
inline float* row_offset(Tensor& input, int start) {  // NOLINT
  auto in_dim = input.dims();
  int width = input.numel() / in_dim[0];
  int offset = start < in_dim[0] ? start * width : input.numel();
  return input.mutable_data<float>() + offset;
}

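// Raw-pointer view of the buffers for one LSTM step. The check_* members
// are the optional peephole weights and prev_state_value is the previous
// cell state; any of them may be null.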
template <class T>
struct LstmMetaValue {
  T* gate_value;          // 4 * frame_size pre-activation gate values per item
  T* prev_state_value;    // cell state from the previous step (may be null)
  T* state_value;         // cell state produced by this step
  T* state_active_value;  // activated cell state, feeds the output
  T* output_value;        // hidden-state output of this step
  T* check_ig;            // peephole weights, input gate (may be null)
  T* check_fg;            // peephole weights, forget gate (may be null)
  T* check_og;            // peephole weights, output gate (may be null)
};

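// Dispatches by name to the ARM activation kernels; anything other than
// "sigmoid", "tanh", or "relu" is a fatal error.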
template <typename T>
void activation(
    const T* din, T* dout, int size, const std::string& act_str, int threads) {
  if (act_str == "sigmoid") {
    act_sigmoid(din, dout, size, threads);
  } else if (act_str == "tanh") {
    act_tanh(din, dout, size, threads);
  } else if (act_str == "relu") {
    act_relu(din, dout, size, threads);
  } else {
    LOG(FATAL) << "unsupported activation " << act_str;
  }
}

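// Element-wise fused multiply (defined in the corresponding .cc file).
// Judging from the call sites below: out[i] = in[i] * v1[i] when
// v2 == nullptr, and out[i] = in[i] + v1[i] * v2[i] otherwise.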
void vector_dot(float* out,
                const float* in,
                const float* v1,
                int size,
                const float* v2 = nullptr);

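// One LSTM step over a whole batch: activates the gates, applies the
// optional peepholes and cell clipping, updates the cell state, and writes
// the hidden output for each batch item in turn.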
template <typename T>
struct LstmUnitFunctor {
  static void compute(LstmMetaValue<T> value,
                      int frame_size,
                      int batch_size,
                      T cell_clip,
                      const std::string& gate_act,
                      const std::string& cell_act,
                      const std::string& cand_act,
                      int threads) {
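    // One pass per batch item; the pointers in `value` are advanced at the
    // bottom of the loop.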
    for (int b = 0; b < batch_size; ++b) {
      // Zero scratch buffer; it stands in for any null optional pointer.
      const int temp_len = frame_size;
      T zero_ptr[temp_len];  // NOLINT
      memset(zero_ptr, 0, sizeof(T) * temp_len);

      // The four gate buffers sit back to back in gate_value:
      // [candidate input | input gate | forget gate | output gate].
      T* value_in = value.gate_value;
      T* value_ig = value_in + frame_size;
      T* value_fg = value_ig + frame_size;
      T* value_og = value_fg + frame_size;
      T* state = value.state_value;
      T* state_act = value.state_active_value;
      T* output = value.output_value;

      // Optional inputs fall back to the zero buffer when absent.
      T* check_i = value.check_ig ? value.check_ig : zero_ptr;
      T* check_f = value.check_fg ? value.check_fg : zero_ptr;
      T* check_o = value.check_og ? value.check_og : zero_ptr;
      T* prev_state =
          value.prev_state_value ? value.prev_state_value : zero_ptr;
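
      // Activate the candidate input, add the peephole terms to the input
      // and forget gates, activate both gates, then form the new cell
      // state: state = act(in) * ig + prev_state * fg.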
      activation(value_in, value_in, frame_size, gate_act, threads);
      vector_dot(value_ig, value_ig, prev_state, frame_size, check_i);
      vector_dot(value_fg, value_fg, prev_state, frame_size, check_f);
      activation(value_ig, value_ig, frame_size, cell_act, threads);
      activation(value_fg, value_fg, frame_size, cell_act, threads);
      vector_dot(state, value_in, value_ig, frame_size);
      vector_dot(state, state, prev_state, frame_size, value_fg);

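      // Clamp the cell state to [-cell_clip, cell_clip] when clipping is on.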
      if (cell_clip > T(0)) {
        for (int i = 0; i < frame_size; ++i) {
          if (state[i] < -cell_clip) {
            state[i] = -cell_clip;
          }
          if (state[i] > cell_clip) {
            state[i] = cell_clip;
          }
        }
      }

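      // Output gate (with peephole on the fresh state), then the hidden
      // output: output = og * act(state).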
      vector_dot(value_og, value_og, state, frame_size, check_o);
      activation(value_og, value_og, frame_size, cell_act, threads);
      activation(state, state_act, frame_size, cand_act, threads);
      vector_dot(output, value_og, state_act, frame_size);

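      // Step every pointer forward to the next batch item.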
      value.gate_value += frame_size * 4;
      value.state_value += frame_size;
      value.state_active_value += frame_size;
      value.output_value += frame_size;
      if (value.prev_state_value) {
        value.prev_state_value += frame_size;
      }
    }
  }
};

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle