Commit 9fcb96d

MyPandaShaoxiang authored and root committed
* feat: add lstm op && kernel test=develop
1 parent b3a8fcb commit 9fcb96d

File tree

10 files changed: +648 −0 lines changed

lite/backends/arm/math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -123,5 +123,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
     anchor_generator.cc
     split_merge_lod_tenosr.cc
     reduce_prod.cc
+    lstm.cc
     DEPS ${lite_kernel_deps} context tensor)
 endif()

lite/backends/arm/math/lstm.cc

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/lstm.h"
#include "lite/backends/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Adds the bias columns [start_w, min(end_w, width)) to every row of input.
void add_bias_rowwise(Tensor* input,
                      const Tensor* bias,
                      int start_w,
                      int end_w) {
  auto in_dim = input->dims();
  int width = input->numel() / in_dim[0];
  int w_adds = width < end_w ? width : end_w;
  float* i_data = input->mutable_data<float>();
  const float* b_data = bias->data<float>();
  for (int i = 0; i < in_dim[0]; ++i) {
    for (int w = start_w; w < w_adds; ++w) {
      i_data[w] += b_data[w];
    }
    i_data += width;  // step to the next row
  }
}

// Elementwise kernel with two modes, selected by v2:
//   v2 == nullptr: out[i] = in[i] * v1[i]
//   v2 != nullptr: out[i] = in[i] + v1[i] * v2[i]
void vector_dot(
    float* out, const float* in, const float* v1, int size, const float* v2) {
  int loop = size >> 2;   // full 4-lane NEON iterations
  int remain = size & 3;  // scalar tail
  const float* in_ptr = in;
  float* out_ptr = out;
  const float* v1_ptr = v1;
  const float* v2_ptr = v2;
  for (int i = 0; i < loop; ++i) {
    float32x4_t din = vld1q_f32(in_ptr);
    float32x4_t dv1 = vld1q_f32(v1_ptr);
    if (!v2) {
      // out = in * v1
      vst1q_f32(out_ptr, vmulq_f32(din, dv1));
    } else {
      // out = in + v1 * v2
      float32x4_t dv2 = vld1q_f32(v2_ptr);
      vst1q_f32(out_ptr, vmlaq_f32(din, dv1, dv2));
      v2_ptr += 4;
    }
    in_ptr += 4;
    v1_ptr += 4;
    out_ptr += 4;
  }
  for (int i = 0; i < remain; ++i) {
    out_ptr[i] = v2 ? in_ptr[i] + v1_ptr[i] * v2_ptr[i]  // fused multiply-add
                    : in_ptr[i] * v1_ptr[i];             // plain product
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
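As a quick cross-check of `vector_dot`'s two modes, here is a scalar reference sketch (not part of the commit; `vector_dot_ref` is a hypothetical name introduced for illustration):

// Scalar reference for vector_dot. With v2 == nullptr it computes an
// elementwise product; otherwise an elementwise fused multiply-add.
void vector_dot_ref(float* out, const float* in, const float* v1,
                    int size, const float* v2 = nullptr) {
  for (int i = 0; i < size; ++i) {
    out[i] = v2 ? in[i] + v1[i] * v2[i]  // out = in + v1 .* v2
                : in[i] * v1[i];         // out = in .* v1
  }
}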

lite/backends/arm/math/lstm.h

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <arm_neon.h>
#include <cstring>  // memset
#include <string>
#include "lite/backends/arm/math/activation.h"
#include "lite/core/tensor.h"
#include "lite/utils/logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

void add_bias_rowwise(Tensor* input,
                      const Tensor* bias,
                      int start_w,
                      int end_w);

// Returns a pointer to row `start` of a 2-D tensor, or to the end of the
// buffer when start is out of range.
inline float* row_offset(Tensor& input, int start) {  // NOLINT
  auto in_dim = input.dims();
  int width = input.numel() / in_dim[0];
  int offset = start < in_dim[0] ? start * width : input.numel();
  return input.mutable_data<float>() + offset;
}

// Workspace pointers for one LSTM step. gate_value packs the four gate
// pre-activations per batch row; check_* are optional peephole weights.
template <class T>
struct LstmMetaValue {
  T* gate_value;
  T* prev_state_value;
  T* state_value;
  T* state_active_value;
  T* output_value;
  T* check_ig;
  T* check_fg;
  T* check_og;
};

template <typename T>
void activation(
    const T* din, T* dout, int size, std::string act_str, int threads) {
  if (act_str == "sigmoid") {
    act_sigmoid(din, dout, size, threads);
  } else if (act_str == "tanh") {
    act_tanh(din, dout, size, threads);
  } else if (act_str == "relu") {
    act_relu(din, dout, size, threads);
  } else {
    LOG(FATAL) << "unsupported activation " << act_str;
  }
}

void vector_dot(float* out,
                const float* in,
                const float* v1,
                int size,
                const float* v2 = nullptr);

template <typename T>
struct LstmUnitFunctor {
  static void compute(LstmMetaValue<T> value,
                      int frame_size,
                      int batch_size,
                      T cell_clip,
                      std::string gate_act,
                      std::string cell_act,
                      std::string cand_act,
                      int threads) {
    for (int b = 0; b < batch_size; ++b) {
      // Zero buffer stands in for an absent previous state or absent
      // peephole weights.
      const int temp_len = frame_size;
      float zero_ptr[temp_len];  // NOLINT
      memset(zero_ptr, 0, sizeof(float) * temp_len);

      // Gate pre-activations are packed [candidate, input, forget, output].
      T* value_in = value.gate_value;
      T* value_ig = value_in + frame_size;
      T* value_fg = value_ig + frame_size;
      T* value_og = value_fg + frame_size;
      T* state = value.state_value;
      T* state_act = value.state_active_value;

      T* check_i = value.check_ig ? value.check_ig : zero_ptr;
      T* check_f = value.check_fg ? value.check_fg : zero_ptr;
      T* check_o = value.check_og ? value.check_og : zero_ptr;
      T* prev_state =
          value.prev_state_value ? value.prev_state_value : zero_ptr;

      // Candidate and gates, with peephole contributions from C_{t-1}.
      activation(value_in, value_in, frame_size, gate_act, threads);
      vector_dot(value_ig, value_ig, prev_state, frame_size, check_i);
      vector_dot(value_fg, value_fg, prev_state, frame_size, check_f);
      activation(value_ig, value_ig, frame_size, cell_act, threads);
      activation(value_fg, value_fg, frame_size, cell_act, threads);

      // C_t = candidate * i  +  C_{t-1} * f
      vector_dot(state, value_in, value_ig, frame_size);
      vector_dot(state, state, prev_state, frame_size, value_fg);

      // Optional cell clipping.
      if (cell_clip > 0.0) {
        for (int i = 0; i < frame_size; ++i) {
          if (state[i] < -cell_clip) state[i] = -cell_clip;
          if (state[i] > cell_clip) state[i] = cell_clip;
        }
      }

      // H_t = o (with peephole from C_t) * act(C_t)
      vector_dot(value_og, value_og, state, frame_size, check_o);
      activation(value_og, value_og, frame_size, cell_act, threads);
      activation(state, state_act, frame_size, cand_act, threads);
      vector_dot(value.output_value, value_og, state_act, frame_size);

      // Advance every pointer to the next batch row.
      value.gate_value += frame_size * 4;
      value.state_value += frame_size;
      value.state_active_value += frame_size;
      value.output_value += frame_size;
      if (value.prev_state_value) {
        value.prev_state_value += frame_size;
      }
    }
  }
};

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
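Taken together, `compute` performs one step of a standard peephole LSTM per batch row: C_t = act(candidate) * gate_i + C_{t-1} * gate_f (with optional clipping), then H_t = gate_o * act(C_t). Below is a minimal calling sketch; the buffer layout is inferred from the pointer arithmetic above, and every name outside the `paddle::lite::arm::math` namespace is hypothetical:

#include <vector>
#include "lite/backends/arm/math/lstm.h"

int main() {
  const int frame_size = 8;
  const int batch_size = 2;

  // Gate pre-activations packed per batch row as
  // [candidate | input | forget | output], each frame_size wide.
  std::vector<float> gates(4 * frame_size * batch_size, 0.1f);
  std::vector<float> cell(frame_size * batch_size);      // C_t, written
  std::vector<float> cell_act(frame_size * batch_size);  // act(C_t), written
  std::vector<float> hidden(frame_size * batch_size);    // H_t, written

  paddle::lite::arm::math::LstmMetaValue<float> v{};
  v.gate_value = gates.data();
  v.prev_state_value = nullptr;  // first step: C_{t-1} is treated as zeros
  v.state_value = cell.data();
  v.state_active_value = cell_act.data();
  v.output_value = hidden.data();
  // check_ig / check_fg / check_og stay null: no peephole connections.

  // Note the argument order the functor actually applies: gate_act to the
  // candidate block, cell_act to the i/f/o gates, cand_act to C_t.
  paddle::lite::arm::math::LstmUnitFunctor<float>::compute(
      v, frame_size, batch_size, /*cell_clip=*/0.f,
      /*gate_act=*/"tanh", /*cell_act=*/"sigmoid", /*cand_act=*/"tanh",
      /*threads=*/1);
  return 0;
}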

lite/kernels/arm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${
 add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)

 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
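The `lstm_compute.cc` source wired in here is part of the commit but not shown in this excerpt. For orientation only: Lite ARM kernels are exposed through the `REGISTER_LITE_KERNEL` macro, and a registration for this kernel would typically look like the sketch below (class name and bindings are illustrative guesses, not the commit's actual code):

// Hypothetical registration sketch; the real lstm_compute.cc is not
// shown above, and these bindings are guesses for illustration.
REGISTER_LITE_KERNEL(
    lstm, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LstmCompute, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();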
