// Copyright 2024 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#include "layernorm_riscv.h"

#include <math.h>

#if __riscv_vector
#include <riscv_vector.h>
#include "riscv_usability.h"
#endif // __riscv_vector

#include "cpu.h"

namespace ncnn {

LayerNorm_riscv::LayerNorm_riscv()
{
#if __riscv_vector
    support_packing = true;
#endif // __riscv_vector
#if NCNN_ZFH
#if __riscv_vector
    support_fp16_storage = cpu_support_riscv_zvfh();
#else
    support_fp16_storage = cpu_support_riscv_zfh();
#endif
#endif
}

#if __riscv_vector
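// layernorm over `size` contiguous floats (elempack == 1):
// pass 1 reduces the sum with e32m8 vectors to get the mean,
// pass 2 reduces the centered squared sum to get the variance,
// pass 3 rewrites each element as x * a + b, optionally scaled by gamma/beta.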
static inline int layernorm_rvv_pack1_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine)
{
    float sum = 0.f;
    float sqsum = 0.f;
    size_t vl_max = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t _sum = __riscv_vfmv_s_f_f32m1(0.f, vl_max);
    vfloat32m1_t _sqsum = __riscv_vfmv_s_f_f32m1(0.f, vl_max);

    {
        int n = size;
        float* ptr_sum = ptr;
        while (n > 0)
        {
            size_t vl = __riscv_vsetvl_e32m8(n);
            vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_sum, vl);
            _sum = __riscv_vfredusum_vs_f32m8_f32m1(_p, _sum, vl);
            ptr_sum += vl;
            n -= vl;
        }
    }

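    // extract the scalar sum and compute the mean, then make a second pass
    // that subtracts the mean and accumulates (x - mean)^2 into _sqsum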
    sum = __riscv_vfmv_f_s_f32m1_f32(_sum);
    float mean = sum / size;

    {
        int n = size;
        float* ptr_sqsum = ptr;
        while (n > 0)
        {
            size_t vl = __riscv_vsetvl_e32m8(n);
            vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_sqsum, vl);
            _p = __riscv_vfsub_vf_f32m8(_p, mean, vl);
            _sqsum = __riscv_vfredusum_vs_f32m8_f32m1(__riscv_vfmul_vv_f32m8(_p, _p, vl), _sqsum, vl);
            n -= vl;
            ptr_sqsum += vl;
        }
    }

    sqsum = __riscv_vfmv_f_s_f32m1_f32(_sqsum);
    float var = sqsum / size;

    float a = static_cast<float>(1.f / (sqrt(var + eps)));
    float b = -mean * a;

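    // the normalization folds into a single affine transform y = x * a + b,
    // with a = 1 / sqrt(var + eps) and b = -mean * a; when affine, gamma and
    // beta are applied in the same pass via vfmadd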
    {
        int n = size;
        float* ptr_store = ptr;
        const float* ptr_gamma = gamma_data;
        const float* ptr_beta = beta_data;
        if (affine)
        {
            while (n > 0)
            {
                size_t vl = __riscv_vsetvl_e32m8(n);
                vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_store, vl);
                _p = __riscv_vfmul_vf_f32m8(_p, a, vl);
                vfloat32m8_t _gamma = __riscv_vle32_v_f32m8(ptr_gamma, vl);
                _p = __riscv_vfadd_vf_f32m8(_p, b, vl);
                vfloat32m8_t _beta = __riscv_vle32_v_f32m8(ptr_beta, vl);
                _p = __riscv_vfmadd_vv_f32m8(_p, _gamma, _beta, vl);
                __riscv_vse32_v_f32m8(ptr_store, _p, vl);

                n -= vl;
                ptr_store += vl;
                ptr_gamma += vl;
                ptr_beta += vl;
            }
        }
        else
        {
            while (n > 0)
            {
                size_t vl = __riscv_vsetvl_e32m8(n);
                vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr_store, vl);
                _p = __riscv_vfmul_vf_f32m8(_p, a, vl);
                _p = __riscv_vfadd_vf_f32m8(_p, b, vl);
                __riscv_vse32_v_f32m8(ptr_store, _p, vl);
                n -= vl;
                ptr_store += vl;
            }
        }
    }
    return 0;
}

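// layernorm for packed layout (elempack == packn): each e32m1 vector holds one
// lane per packed value, so mean and variance are accumulated per lane across
// the `size` packed elements, and gamma/beta are broadcast per element index.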
static inline int layernorm_rvv_packn_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine, const size_t vl)
{
    vfloat32m1_t _sum = __riscv_vfmv_v_f_f32m1(0.f, vl);
    vfloat32m1_t _sqsum = __riscv_vfmv_v_f_f32m1(0.f, vl);
    for (int i = 0; i < size; i++)
    {
        vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + vl * i, vl);
        _sum = __riscv_vfadd_vv_f32m1(_p, _sum, vl);
        // _sqsum = vfmadd_vv_f32m1(_p,_p,_sqsum,vl);
    }
    vfloat32m1_t _mean = __riscv_vfdiv_vf_f32m1(_sum, size, vl);
    for (int i = 0; i < size; i++)
    {
        vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + vl * i, vl);
        _p = __riscv_vfsub_vv_f32m1(_p, _mean, vl);
        _sqsum = __riscv_vfmacc_vv_f32m1(_sqsum, _p, _p, vl);
    }
    vfloat32m1_t _var = __riscv_vfdiv_vf_f32m1(_sqsum, size, vl);
    vfloat32m1_t _a = __riscv_vfrdiv_vf_f32m1(__riscv_vfsqrt_v_f32m1(__riscv_vfadd_vf_f32m1(_var, eps, vl), vl), 1.f, vl);
    vfloat32m1_t _b = __riscv_vfmul_vv_f32m1(__riscv_vfsgnjn_vv_f32m1(_mean, _mean, vl), _a, vl);
    if (affine)
    {
        for (int i = 0; i < size; i++)
        {
            const int offset = vl * i;
            vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + offset, vl);
            _p = __riscv_vfmadd_vv_f32m1(_p, _a, _b, vl);
            _p = __riscv_vfmul_vf_f32m1(_p, gamma_data[i], vl);
            _p = __riscv_vfadd_vf_f32m1(_p, beta_data[i], vl);
            __riscv_vse32_v_f32m1(ptr + offset, _p, vl);
        }
    }
    else
    {
        for (int i = 0; i < size; i++)
        {
            const int offset = vl * i;
            vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr + offset, vl);
            _p = __riscv_vfmadd_vv_f32m1(_p, _a, _b, vl);
            __riscv_vse32_v_f32m1(ptr + offset, _p, vl);
        }
    }

    return 0;
}
#else
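// without the RISC-V vector extension, fall back to a plain scalar layernorm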
static inline int layernorm_scalar_procedure(int size, float* ptr, const float* gamma_data, const float* beta_data, float eps, int affine)
{
    // mean and var
    float sum = 0.f;
    float sqsum = 0.f;
    for (int i = 0; i < size; i++) sum += ptr[i];

    float mean = sum / size;
    float tmp = 0.f;
    for (int i = 0; i < size; i++)
    {
        tmp = ptr[i] - mean;
        sqsum += tmp * tmp;
    }

    float var = sqsum / size;

    float a = static_cast<float>(1.f / (sqrt(var + eps)));
    float b = -mean * a;

    if (affine)
        for (int i = 0; i < size; i++) ptr[i] = (ptr[i] * a + b) * gamma_data[i] + beta_data[i];
    else
        for (int i = 0; i < size; i++) ptr[i] = ptr[i] * a + b;

    return 0;
}
#endif // __riscv_vector

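// forward_inplace dispatches on storage type and layout: fp16 storage goes to
// the dedicated fp16s/fp16sa paths, dims == 1 normalizes the whole blob,
// elempack == 1 uses the pack1 (or scalar) kernel per row or per channel, and
// elempack == packn uses the per-lane packed kernel.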
int LayerNorm_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
#if NCNN_ZFH
    int elembits = bottom_top_blob.elembits();
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif // NCNN_ZFH

    int elempack = bottom_top_blob.elempack;
    int dims = bottom_top_blob.dims;
    int w = bottom_top_blob.w;

    if (dims == 1)
    {
        float* ptr = bottom_top_blob;
#if __riscv_vector
        return layernorm_rvv_pack1_procedure(w * elempack, ptr, gamma_data, beta_data, eps, affine);
#else
        return layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
    }
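    // elempack == 1 (the only layout without RVV): normalize each row for
    // dims == 2, and each row or each whole channel for dims == 3 depending
    // on affine_size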
#if __riscv_vector
    if (elempack == 1)
#endif
    {
        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            // assert affine_size == w
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                float* ptr = bottom_top_blob.row(i);
#if __riscv_vector
                layernorm_rvv_pack1_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#else
                layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
            }
        }

        if (dims == 3)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int channels = bottom_top_blob.c;
            int size = w * h;
            if (affine_size == w)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    for (int i = 0; i < h; i++)
                    {
                        float* ptr = bottom_top_blob.channel(q).row(i);
#if __riscv_vector
                        layernorm_rvv_pack1_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#else
                        layernorm_scalar_procedure(w, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
                    }
                }
            }
            else // if (affine_size == size)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
#if __riscv_vector
                    layernorm_rvv_pack1_procedure(size, ptr, gamma_data, beta_data, eps, affine);
#else
                    layernorm_scalar_procedure(size, ptr, gamma_data, beta_data, eps, affine);
#endif // __riscv_vector
                }
            }
        }
    }

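    // elempack == packn: statistics are kept per vector lane, so the packn
    // values bundled into one packed element are normalized independently of
    // each other by the packn kernel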
276
+ #if __riscv_vector
277
+ const int packn = csrr_vlenb () / 4 ;
278
+ if (elempack == packn)
279
+ {
280
+ const size_t vl = __riscv_vsetvl_e32m1 (packn);
281
+ if (dims == 2 )
282
+ {
283
+ int w = bottom_top_blob.w ;
284
+ int h = bottom_top_blob.h ;
285
+ // assert affine_size == w
286
+
287
+ #pragma omp parallel for num_threads(opt.num_threads)
288
+ for (int i = 0 ; i < h; i++)
289
+ {
290
+ float * ptr = bottom_top_blob.row (i);
291
+ layernorm_rvv_packn_procedure (w, ptr, gamma_data, beta_data, eps, affine, vl);
292
+ }
293
+ }
294
+ if (dims == 3 )
295
+ {
296
+ int w = bottom_top_blob.w ;
297
+ int h = bottom_top_blob.h ;
298
+ int channels = bottom_top_blob.c ;
299
+ int size = w * h;
300
+
301
+ if (affine_size == w)
302
+ {
303
+ #pragma omp parallel for num_threads(opt.num_threads)
304
+ for (int q = 0 ; q < channels; q++)
305
+ {
306
+ for (int i = 0 ; i < h; i++)
307
+ {
308
+ float * ptr = bottom_top_blob.channel (q).row (i);
309
+
310
+ layernorm_rvv_packn_procedure (w, ptr, gamma_data, beta_data, eps, affine, vl);
311
+ }
312
+ }
313
+ }
314
+ else // if (affine_size == size)
315
+ {
316
+ #pragma omp parallel for num_threads(opt.num_threads)
317
+ for (int q = 0 ; q < channels; q++)
318
+ {
319
+ float * ptr = bottom_top_blob.channel (q);
320
+ layernorm_rvv_packn_procedure (size, ptr, gamma_data, beta_data, eps, affine, vl);
321
+ }
322
+ }
323
+ }
324
+ }
325
+ #endif // __riscv_vector
326
+ return 0 ;
327
+ }
328
+ } // namespace ncnn