@@ -88,7 +88,8 @@ int LayerNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
88
88
float * ptr = bottom_top_blob;
89
89
layernorm (ptr, gamma_data, beta_data, eps, w);
90
90
}
91
- else if (dims == 2 )
91
+
92
+ if (dims == 2 )
92
93
{
93
94
int w = bottom_top_blob.w ;
94
95
int h = bottom_top_blob.h ;
@@ -101,35 +102,32 @@ int LayerNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
101
102
layernorm (ptr, gamma_data, beta_data, eps, w);
102
103
}
103
104
}
104
- else if (dims == 3 )
105
+
106
+ if (dims == 3 )
105
107
{
106
108
int w = bottom_top_blob.w ;
107
109
int h = bottom_top_blob.h ;
108
110
int channels = bottom_top_blob.c ;
109
111
110
- int group_size;
111
- int num_groups_per_channel;
112
-
113
112
if (affine_size == w)
114
113
{
115
- group_size = w;
116
- num_groups_per_channel = h;
117
- }
118
- else // if (affine_size == w * h)
119
- {
120
- group_size = w * h;
121
- num_groups_per_channel = 1 ;
114
+ #pragma omp parallel for num_threads(opt.num_threads)
115
+ for (int q = 0 ; q < channels; q++)
116
+ {
117
+ for (int i = 0 ; i < h; i++)
118
+ {
119
+ float * ptr = bottom_top_blob.channel (q).row (i);
120
+ layernorm (ptr, gamma_data, beta_data, eps, w);
121
+ }
122
+ }
122
123
}
123
-
124
- #pragma omp parallel for num_threads(opt.num_threads)
125
- for (int q = 0 ; q < channels; q++)
124
+ else // if (affine_size == size)
126
125
{
127
- float * channel_ptr = bottom_top_blob.channel (q);
128
-
129
- for (int i = 0 ; i < num_groups_per_channel; i++)
126
+ #pragma omp parallel for num_threads(opt.num_threads)
127
+ for (int q = 0 ; q < channels; q++)
130
128
{
131
- float * ptr = channel_ptr + i * group_size ;
132
- layernorm (ptr, gamma_data, beta_data, eps, group_size );
129
+ float * ptr = bottom_top_blob. channel (q) ;
130
+ layernorm (ptr, gamma_data, beta_data, eps, w * h );
133
131
}
134
132
}
135
133
}
0 commit comments