Skip to content

Commit b4aa587

Browse files
committed
fix: revert layernorm(cpu)
1 parent 1cf9f0d commit b4aa587

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

src/layer/layernorm.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ int LayerNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
8888
float* ptr = bottom_top_blob;
8989
layernorm(ptr, gamma_data, beta_data, eps, w);
9090
}
91-
else if (dims == 2)
91+
92+
if (dims == 2)
9293
{
9394
int w = bottom_top_blob.w;
9495
int h = bottom_top_blob.h;
@@ -101,35 +102,32 @@ int LayerNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
101102
layernorm(ptr, gamma_data, beta_data, eps, w);
102103
}
103104
}
104-
else if (dims == 3)
105+
106+
if (dims == 3)
105107
{
106108
int w = bottom_top_blob.w;
107109
int h = bottom_top_blob.h;
108110
int channels = bottom_top_blob.c;
109111

110-
int group_size;
111-
int num_groups_per_channel;
112-
113112
if (affine_size == w)
114113
{
115-
group_size = w;
116-
num_groups_per_channel = h;
117-
}
118-
else // if (affine_size == w * h)
119-
{
120-
group_size = w * h;
121-
num_groups_per_channel = 1;
114+
#pragma omp parallel for num_threads(opt.num_threads)
115+
for (int q = 0; q < channels; q++)
116+
{
117+
for (int i = 0; i < h; i++)
118+
{
119+
float* ptr = bottom_top_blob.channel(q).row(i);
120+
layernorm(ptr, gamma_data, beta_data, eps, w);
121+
}
122+
}
122123
}
123-
124-
#pragma omp parallel for num_threads(opt.num_threads)
125-
for (int q = 0; q < channels; q++)
124+
else // if (affine_size == size)
126125
{
127-
float* channel_ptr = bottom_top_blob.channel(q);
128-
129-
for (int i = 0; i < num_groups_per_channel; i++)
126+
#pragma omp parallel for num_threads(opt.num_threads)
127+
for (int q = 0; q < channels; q++)
130128
{
131-
float* ptr = channel_ptr + i * group_size;
132-
layernorm(ptr, gamma_data, beta_data, eps, group_size);
129+
float* ptr = bottom_top_blob.channel(q);
130+
layernorm(ptr, gamma_data, beta_data, eps, w * h);
133131
}
134132
}
135133
}

src/layer/layernorm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ class LayerNorm : public Layer
3232

3333
} // namespace ncnn
3434

35-
#endif // LAYER_LAYERNORM_H
35+
#endif // LAYER_LAYERNORM_H

0 commit comments

Comments
 (0)