Commit 29ab21a

Add Layernorm rvv1.0 optimization (#6225)
1 parent 1289458 commit 29ab21a

File tree

3 files changed: +937 additions, −0 deletions

src/layer/riscv/layernorm_riscv.cpp

Lines changed: 328 additions & 0 deletions
@@ -0,0 +1,328 @@
// Copyright 2024 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "layernorm_riscv.h"
#include <math.h>

#if __riscv_vector
#include <riscv_vector.h>
#include "riscv_usability.h"
#endif // __riscv_vector

#include "cpu.h"

namespace ncnn {

LayerNorm_riscv::LayerNorm_riscv()
{
#if __riscv_vector
    support_packing = true;
#endif // __riscv_vector
#if NCNN_ZFH
#if __riscv_vector
    // zvfh: vector half-precision float extension
    support_fp16_storage = cpu_support_riscv_zvfh();
#else
    // zfh: scalar half-precision float extension
    support_fp16_storage = cpu_support_riscv_zfh();
#endif
#endif
}

#if __riscv_vector
static inline int layernorm_rvv_pack1_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine)
{
    float sum = 0.f;
    float sqsum = 0.f;
    size_t vl_max = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t _sum = __riscv_vfmv_s_f_f32m1(0.f, vl_max);
    vfloat32m1_t _sqsum = __riscv_vfmv_s_f_f32m1(0.f, vl_max);

    // pass 1: reduce the element sum to compute the mean
    {
        int n = size;
        float* ptr_sum = ptr;
        while (n > 0)
        {
            size_t vl = __riscv_vsetvl_e32m8(n);
            vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_sum, vl);
            _sum = __riscv_vfredusum_vs_f32m8_f32m1(_p, _sum, vl);
            ptr_sum += vl;
            n -= vl;
        }
    }

    sum = __riscv_vfmv_f_s_f32m1_f32(_sum);
    float mean = sum / size;

    // pass 2: reduce the squared deviation to compute the variance
    {
        int n = size;
        float* ptr_sqsum = ptr;
        while (n > 0)
        {
            size_t vl = __riscv_vsetvl_e32m8(n);
            vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_sqsum, vl);
            _p = __riscv_vfsub_vf_f32m8(_p, mean, vl);
            _sqsum = __riscv_vfredusum_vs_f32m8_f32m1(__riscv_vfmul_vv_f32m8(_p, _p, vl), _sqsum, vl);
            n -= vl;
            ptr_sqsum += vl;
        }
    }

    sqsum = __riscv_vfmv_f_s_f32m1_f32(_sqsum);
    float var = sqsum / size;

    float a = static_cast<float>(1.f / (sqrt(var + eps)));
    float b = -mean * a;

    // pass 3: apply y = x * a + b, then gamma/beta when affine
    {
        int n = size;
        float* ptr_store = ptr;
        const float* ptr_gamma = gamma_data;
        const float* ptr_beta = beta_data;
        if (affine)
        {
            while (n > 0)
            {
                size_t vl = __riscv_vsetvl_e32m8(n);
                vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_store, vl);
                _p = __riscv_vfmul_vf_f32m8(_p, a, vl);
                vfloat32m8_t _gamma = __riscv_vle32_v_f32m8(ptr_gamma, vl);
                _p = __riscv_vfadd_vf_f32m8(_p, b, vl);
                vfloat32m8_t _beta = __riscv_vle32_v_f32m8(ptr_beta, vl);
                _p = __riscv_vfmadd_vv_f32m8(_p, _gamma, _beta, vl);
                __riscv_vse32_v_f32m8(ptr_store, _p, vl);

                n -= vl;
                ptr_store += vl;
                ptr_gamma += vl;
                ptr_beta += vl;
            }
        }
        else
        {
            while (n > 0)
            {
                size_t vl = __riscv_vsetvl_e32m8(n);
                vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_store, vl);
                _p = __riscv_vfmul_vf_f32m8(_p, a, vl);
                _p = __riscv_vfadd_vf_f32m8(_p, b, vl);
                __riscv_vse32_v_f32m8(ptr_store, _p, vl);
                n -= vl;
                ptr_store += vl;
            }
        }
    }
    return 0;
}

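// Aside, not part of this commit: the two reductions above follow the standard
// RVV strip-mining idiom. __riscv_vsetvl_e32m8(n) clamps vl to the remaining
// element count, so the loop needs no scalar tail, and the m8 register grouping
// processes eight vector registers' worth of floats per trip. In isolation the
// idiom looks like this hypothetical helper:
static inline float reduce_sum_rvv_sketch(const float* x, int n)
{
    // the running sum lives in lane 0 of an m1 register
    vfloat32m1_t acc = __riscv_vfmv_s_f_f32m1(0.f, __riscv_vsetvlmax_e32m1());
    while (n > 0)
    {
        size_t vl = __riscv_vsetvl_e32m8(n);
        vfloat32m8_t v = __riscv_vle32_v_f32m8(x, vl);
        acc = __riscv_vfredusum_vs_f32m8_f32m1(v, acc, vl); // fold vl lanes into lane 0
        x += vl;
        n -= vl;
    }
    return __riscv_vfmv_f_s_f32m1_f32(acc);
}
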
static inline int layernorm_rvv_packn_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine, const size_t vl)
{
    // per-lane statistics: each of the vl packed lanes gets its own mean and variance
    vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl);
    vfloat32m1_t _sqsum = __riscv_vfmv_v_f_f32m1(0.f, vl);
    for (int i = 0; i < size; i++)
    {
        vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + vl * i, vl);
        _sum = __riscv_vfadd_vv_f32m1(_p, _sum, vl);
    }
    vfloat32m1_t _mean = __riscv_vfdiv_vf_f32m1(_sum, size, vl);
    for (int i = 0; i < size; i++)
    {
        vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + vl * i, vl);
        _p = __riscv_vfsub_vv_f32m1(_p, _mean, vl);
        _sqsum = __riscv_vfmacc_vv_f32m1(_sqsum, _p, _p, vl);
    }
    vfloat32m1_t _var = __riscv_vfdiv_vf_f32m1(_sqsum, size, vl);
    // _a = 1 / sqrt(var + eps); _b = -mean * _a (vfsgnjn negates _mean)
    vfloat32m1_t _a = __riscv_vfrdiv_vf_f32m1(__riscv_vfsqrt_v_f32m1(__riscv_vfadd_vf_f32m1(_var, eps, vl), vl), 1.f, vl);
    vfloat32m1_t _b = __riscv_vfmul_vv_f32m1(__riscv_vfsgnjn_vv_f32m1(_mean, _mean, vl), _a, vl);
    if (affine)
    {
        for (int i = 0; i < size; i++)
        {
            const int offset = vl * i;
            vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + offset, vl);
            _p = __riscv_vfmadd_vv_f32m1(_p, _a, _b, vl);
            _p = __riscv_vfmul_vf_f32m1(_p, gamma_data[i], vl);
            _p = __riscv_vfadd_vf_f32m1(_p, beta_data[i], vl);
            __riscv_vse32_v_f32m1(ptr + offset, _p, vl);
        }
    }
    else
    {
        for (int i = 0; i < size; i++)
        {
            const int offset = vl * i;
            vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + offset, vl);
            _p = __riscv_vfmadd_vv_f32m1(_p, _a, _b, vl);
            __riscv_vse32_v_f32m1(ptr + offset, _p, vl);
        }
    }

    return 0;
}
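// Aside, not part of this commit: in packed layout, element i of packed lane c
// lives at ptr[i * vl + c], and each of the vl lanes is normalized
// independently, while gamma/beta are indexed by i (the normalized axis) and
// broadcast across lanes. A hypothetical scalar reference for what the packn
// kernel above computes, vl lanes at a time:
static inline void layernorm_packn_reference_sketch(int size, int packn, float* ptr, const float* gamma, const float* beta, float eps, int affine)
{
    for (int c = 0; c < packn; c++)
    {
        float sum = 0.f;
        for (int i = 0; i < size; i++)
            sum += ptr[i * packn + c];
        float mean = sum / size;

        float sqsum = 0.f;
        for (int i = 0; i < size; i++)
        {
            float t = ptr[i * packn + c] - mean;
            sqsum += t * t;
        }
        float a = 1.f / sqrtf(sqsum / size + eps);
        float b = -mean * a;

        for (int i = 0; i < size; i++)
        {
            float v = ptr[i * packn + c] * a + b;
            ptr[i * packn + c] = affine ? v * gamma[i] + beta[i] : v;
        }
    }
}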
#else
static inline int layernorm_scalar_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine)
{
    // mean and var
    float sum = 0.f;
    float sqsum = 0.f;
    for (int i = 0; i < size; i++)
        sum += ptr[i];

    float mean = sum / size;
    float tmp = 0.f;
    for (int i = 0; i < size; i++)
    {
        tmp = ptr[i] - mean;
        sqsum += tmp * tmp;
    }

    float var = sqsum / size;

    float a = static_cast<float>(1.f / (sqrt(var + eps)));
    float b = -mean * a;

    if (affine)
    {
        for (int i = 0; i < size; i++)
            ptr[i] = (ptr[i] * a + b) * gamma_data[i] + beta_data[i];
    }
    else
    {
        for (int i = 0; i < size; i++)
            ptr[i] = ptr[i] * a + b;
    }

    return 0;
}
#endif // __riscv_vector

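// All three kernels above fold the normalization into a single affine map per
// element: with a = 1 / sqrt(var + eps) and b = -mean * a,
//   (x - mean) / sqrt(var + eps) = x * a + b,
// so the store pass costs one multiply and one add per element, plus the
// gamma/beta multiply-add when affine is set.
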
int LayerNorm_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
#if NCNN_ZFH
    int elembits = bottom_top_blob.elembits();
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif // NCNN_ZFH

    int elempack = bottom_top_blob.elempack;
    int dims = bottom_top_blob.dims;
    int w = bottom_top_blob.w;

    if (dims == 1)
    {
        float* ptr = bottom_top_blob;
#if __riscv_vector
        // a packed 1-D blob is still one contiguous run of w * elempack floats
        return layernorm_rvv_pack1_procedure(w * elempack, ptr, gamma_data, beta_data, eps, affine);
#else
        return layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
    }
#if __riscv_vector
    if (elempack == 1)
#endif
    {
        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            // assert affine_size == w
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                float* ptr = bottom_top_blob.row(i);
#if __riscv_vector
                layernorm_rvv_pack1_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#else
                layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
            }
        }

        if (dims == 3)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int channels = bottom_top_blob.c;
            int size = w * h;
            if (affine_size == w)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    for (int i = 0; i < h; i++)
                    {
                        float* ptr = bottom_top_blob.channel(q).row(i);
#if __riscv_vector
                        layernorm_rvv_pack1_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#else
                        layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
                    }
                }
            }
            else // if (affine_size == size)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
#if __riscv_vector
                    layernorm_rvv_pack1_procedure(size, ptr, gamma_data, beta_data, eps, affine);
#else
                    layernorm_scalar_procedure(size, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
                }
            }
        }
    }

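    // The packed path below handles blobs whose elements ncnn has interleaved
    // via convert_packing: with elempack == packn, each "element" holds one
    // float from each of packn consecutive slices of the outermost axis
    // (rows for 2-D blobs, channels for 3-D blobs).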
#if __riscv_vector
    const int packn = csrr_vlenb() / 4; // vlenb = vector register width in bytes, so packn = fp32 lanes per register
    if (elempack == packn)
    {
        const size_t vl = __riscv_vsetvl_e32m1(packn);
        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            // assert affine_size == w

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                float* ptr = bottom_top_blob.row(i);
                layernorm_rvv_packn_procedure(w, ptr, gamma_data, beta_data, eps, affine, vl);
            }
        }
        if (dims == 3)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int channels = bottom_top_blob.c;
            int size = w * h;

            if (affine_size == w)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    for (int i = 0; i < h; i++)
                    {
                        float* ptr = bottom_top_blob.channel(q).row(i);

                        layernorm_rvv_packn_procedure(w, ptr, gamma_data, beta_data, eps, affine, vl);
                    }
                }
            }
            else // if (affine_size == size)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
                    layernorm_rvv_packn_procedure(size, ptr, gamma_data, beta_data, eps, affine, vl);
                }
            }
        }
    }
#endif // __riscv_vector
    return 0;
}
} // namespace ncnn

src/layer/riscv/layernorm_riscv.h

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
// Copyright 2024 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_LAYERNORM_RISCV_H
#define LAYER_LAYERNORM_RISCV_H

#include "layernorm.h"

namespace ncnn {
class LayerNorm_riscv : public LayerNorm
{
public:
    LayerNorm_riscv();

    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

protected:
#if NCNN_ZFH
    int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const;
    int forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const;
#endif
};

} // namespace ncnn

#endif // LAYER_LAYERNORM_RISCV_H
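
For context, a minimal sketch of exercising this layer through ncnn's layer factory. It is not part of the commit: the param ids (0 = affine_size, 1 = eps, 2 = affine) are assumed from ncnn's LayerNorm definition, and affine is disabled so no gamma/beta weights need loading. On a RISC-V build with vector support, create_layer returns this LayerNorm_riscv implementation behind the generic Layer interface.

#include "layer.h"

// normalize a blob in place through the registered LayerNorm (hypothetical harness)
int run_layernorm(ncnn::Mat& blob, int affine_size)
{
    ncnn::Layer* op = ncnn::create_layer("LayerNorm");

    ncnn::ParamDict pd;
    pd.set(0, affine_size); // 0 = affine_size
    pd.set(1, 1e-5f);       // 1 = eps
    pd.set(2, 0);           // 2 = affine, disabled here
    op->load_param(pd);

    ncnn::Option opt;
    op->create_pipeline(opt);
    int ret = op->forward_inplace(blob, opt);
    op->destroy_pipeline(opt);

    delete op;
    return ret;
}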
