
Commit d01715a

[transformer] add rms-norm (#2396)
* [transformer] add rms-norm
* fix assert
1 parent e5fd5c0 commit d01715a

File tree: 7 files changed (+82, -25 lines)

wenet/transformer/convolution.py

Lines changed: 5 additions & 3 deletions
@@ -19,6 +19,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model."""
@@ -68,13 +70,13 @@ def __init__(self,
             bias=bias,
         )
 
-        assert norm in ['batch_norm', 'layer_norm']
+        assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
         if norm == "batch_norm":
             self.use_layer_norm = False
-            self.norm = nn.BatchNorm1d(channels)
+            self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
         else:
             self.use_layer_norm = True
-            self.norm = nn.LayerNorm(channels)
+            self.norm = WENET_NORM_CLASSES[norm](channels)
 
         self.pointwise_conv2 = nn.Conv1d(
             channels,
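
With the registry in place, 'rms_norm' takes the same path as 'layer_norm' inside the conformer convolution module: use_layer_norm stays True, so the norm is applied over the last dimension after a transpose (the transpose itself lives in the module's forward, which is not part of this hunk). A minimal sketch of that dispatch, assuming the wenet package is importable; the shapes are illustrative:

import torch
from wenet.utils.class_utils import WENET_NORM_CLASSES

norm_name, channels = 'rms_norm', 256        # 'batch_norm' | 'layer_norm' | 'rms_norm'
assert norm_name in ['batch_norm', 'layer_norm', 'rms_norm']
use_layer_norm = norm_name != 'batch_norm'   # mirrors the branch above
norm = WENET_NORM_CLASSES[norm_name](channels)

x = torch.randn(4, channels, 100)            # (batch, channels, time)
# Non-batch norms normalize the last dim, hence the transpose pair.
y = norm(x.transpose(1, 2)).transpose(1, 2) if use_layer_norm else norm(x)
print(y.shape)                               # torch.Size([4, 256, 100])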

wenet/transformer/decoder.py

Lines changed: 5 additions & 1 deletion
@@ -25,6 +25,7 @@
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
     WENET_MLP_CLASSES,
+    WENET_NORM_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -81,6 +82,7 @@ def __init__(
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -93,8 +95,10 @@ def __init__(
                                                   positional_dropout_rate),
         )
 
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
-        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+        self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
+                                                              eps=1e-5)
         self.use_output_layer = use_output_layer
         if use_output_layer:
             self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
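
In this file only the final after_norm is switched; the hunks shown do not pass the new keyword on to the DecoderLayers. A hedged smoke test of the new argument, assuming the remaining TransformerDecoder arguments keep their defaults; vocab_size=5000 and encoder_output_size=256 are illustrative values only:

from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.norm import RMSNorm

decoder = TransformerDecoder(vocab_size=5000,           # illustrative sizes
                             encoder_output_size=256,
                             layer_norm_type='rms_norm')
print(isinstance(decoder.after_norm, RMSNorm))          # True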

wenet/transformer/decoder_layer.py

Lines changed: 7 additions & 3 deletions
@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -46,16 +48,18 @@ def __init__(
         feed_forward: nn.Module,
         dropout_rate: float,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an DecoderLayer object."""
         super().__init__()
         self.size = size
         self.self_attn = self_attn
         self.src_attn = src_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
-        self.norm3 = nn.LayerNorm(size, eps=1e-5)
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
         self.dropout = nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
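
The three per-layer norms are built from one registry entry, and the uniform call works because LayerNorm and RMSNorm both take the feature size positionally and eps as a keyword; the explicit eps=1e-5 also overrides RMSNorm's own 1e-6 default. A small check, assuming the wenet package is importable:

from torch.nn import LayerNorm
from wenet.transformer.norm import RMSNorm

for cls in (LayerNorm, RMSNorm):
    norm = cls(256, eps=1e-5)
    print(type(norm).__name__, norm.eps)   # prints 'LayerNorm 1e-05', then 'RMSNorm 1e-05'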

wenet/transformer/encoder.py

Lines changed: 18 additions & 11 deletions
@@ -25,6 +25,7 @@
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_MLP_CLASSES,
+    WENET_NORM_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -55,6 +56,7 @@ def __init__(
         use_dynamic_left_chunk: bool = False,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
     ):
         """
         Args:
@@ -102,8 +104,10 @@ def __init__(
                                                   positional_dropout_rate),
         )
 
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
-        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+        self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
+                                                              eps=1e-5)
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -368,6 +372,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
     ):
         """ Construct TransformerEncoder
 
@@ -379,19 +384,21 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(
-                output_size,
-                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-                                                    output_size,
-                                                    attention_dropout_rate,
-                                                    query_bias, key_bias,
-                                                    value_bias, use_sdpa),
-                mlp_class(output_size, linear_units, dropout_rate, activation,
-                          mlp_bias), dropout_rate, normalize_before)
+            TransformerEncoderLayer(output_size,
+                                    WENET_ATTENTION_CLASSES["selfattn"](
+                                        attention_heads, output_size,
+                                        attention_dropout_rate, query_bias,
+                                        key_bias, value_bias, use_sdpa),
+                                    mlp_class(output_size, linear_units,
+                                              dropout_rate, activation,
+                                              mlp_bias),
+                                    dropout_rate,
+                                    normalize_before,
+                                    layer_norm_type=layer_norm_type)
             for _ in range(num_blocks)
         ])
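
Unlike the decoder hunk, TransformerEncoder forwards layer_norm_type both to the shared after_norm (via super().__init__) and to every TransformerEncoderLayer. A hedged smoke test, with input_size=80 standing in for a typical fbank dimension and everything else left at its defaults:

from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.norm import RMSNorm

encoder = TransformerEncoder(input_size=80, layer_norm_type='rms_norm')
print(isinstance(encoder.after_norm, RMSNorm))          # True
print(type(encoder.encoders[0].norm1).__name__)         # RMSNorm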

wenet/transformer/encoder_layer.py

Lines changed: 17 additions & 7 deletions
@@ -20,6 +20,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer module.
@@ -44,13 +46,15 @@ def __init__(
         feed_forward: torch.nn.Module,
         dropout_rate: float,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -135,23 +139,29 @@ def __init__(
         conv_module: Optional[nn.Module] = None,
         dropout_rate: float = 0.1,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
+            size, eps=1e-5)  # for the FNN module
+        self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
+            size, eps=1e-5)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
+                size, eps=1e-5)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
-            self.norm_final = nn.LayerNorm(
+            self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
+                size, eps=1e-5)  # for the CNN module
+            self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
                 size, eps=1e-5)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
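
In the Conformer layer, norm_ff, norm_mha, norm_ff_macaron, norm_conv and norm_final all come from the same registry entry, while the convolution module's internal norm is still chosen by its own norm argument (first file above). A sketch that inspects the constructed norms, assuming the earlier parameters keep their usual names (size, self_attn, feed_forward) and that feed_forward_macaron and conv_module default to None; torch.nn.Identity only stands in for modules the constructor merely stores:

import torch
from wenet.transformer.encoder_layer import ConformerEncoderLayer

layer = ConformerEncoderLayer(size=256,
                              self_attn=torch.nn.Identity(),
                              feed_forward=torch.nn.Identity(),
                              layer_norm_type='rms_norm')
# With no macaron FFN and no conv module, only norm_ff and norm_mha exist.
print(type(layer.norm_ff).__name__, type(layer.norm_mha).__name__)   # RMSNorm RMSNorm
print(hasattr(layer, 'norm_conv'), layer.ff_scale)                   # False 1.0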

wenet/transformer/norm.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import torch
+
+
+class RMSNorm(torch.nn.Module):
+    """ https://arxiv.org/pdf/1910.07467.pdf
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        x = self._norm(x.float()).type_as(x)
+        return x * self.weight
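
RMSNorm (Zhang & Sennrich, 2019) rescales each feature vector by its root mean square, with no mean subtraction and no bias term, unlike LayerNorm. A quick numerical check of the forward pass above, assuming the wenet package is importable:

import torch
from wenet.transformer.norm import RMSNorm

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                    # (batch, time, dim)
rms_norm = RMSNorm(dim=8)

y = rms_norm(x)
# With the default all-ones weight, the output is x / sqrt(mean(x**2) + eps).
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + rms_norm.eps)
print(torch.allclose(y, expected))          # True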

wenet/utils/class_utils.py

Lines changed: 8 additions & 0 deletions
@@ -2,7 +2,9 @@
 # -*- coding: utf-8 -*-
 # Copyright [2023-11-28] <[email protected], Xingchen Song>
 import torch
+from torch.nn import BatchNorm1d, LayerNorm
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.norm import RMSNorm
 from wenet.transformer.positionwise_feed_forward import (
     GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
 
@@ -77,3 +79,9 @@
     'moe': MoEFFNLayer,
     'gated': GatedVariantsMLP
 }
+
+WENET_NORM_CLASSES = {
+    'layer_norm': LayerNorm,
+    'batch_norm': BatchNorm1d,
+    'rms_norm': RMSNorm
+}
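
Call sites select a norm purely by string, so this dict is the single registration point for normalization layers. A short usage sketch, assuming the wenet package is importable:

from wenet.utils.class_utils import WENET_NORM_CLASSES

print(sorted(WENET_NORM_CLASSES))              # ['batch_norm', 'layer_norm', 'rms_norm']
norm = WENET_NORM_CLASSES['rms_norm'](256)     # feature size as the first argument
print(type(norm).__name__)                     # RMSNorm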
