
Commit a93af33

[transformer] add norm eps (#2397)
1 parent d01715a · commit a93af33

File tree: 5 files changed, +63 additions, -37 deletions

wenet/transformer/convolution.py

Lines changed: 13 additions & 9 deletions

@@ -25,13 +25,16 @@
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model."""

-    def __init__(self,
-                 channels: int,
-                 kernel_size: int = 15,
-                 activation: nn.Module = nn.ReLU(),
-                 norm: str = "batch_norm",
-                 causal: bool = False,
-                 bias: bool = True):
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int = 15,
+        activation: nn.Module = nn.ReLU(),
+        norm: str = "batch_norm",
+        causal: bool = False,
+        bias: bool = True,
+        norm_eps: float = 1e-5,
+    ):
         """Construct an ConvolutionModule object.
         Args:
             channels (int): The number of channels of conv layers.
@@ -73,10 +76,11 @@ def __init__(self,
         assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
         if norm == "batch_norm":
             self.use_layer_norm = False
-            self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
+            self.norm = WENET_NORM_CLASSES['batch_norm'](channels,
+                                                         eps=norm_eps)
         else:
             self.use_layer_norm = True
-            self.norm = WENET_NORM_CLASSES[norm](channels)
+            self.norm = WENET_NORM_CLASSES[norm](channels, eps=norm_eps)

         self.pointwise_conv2 = nn.Conv1d(
             channels,
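
For reference, the WENET_NORM_CLASSES lookups above resolve to concrete torch modules that all accept an eps keyword. A minimal sketch, assuming the registry maps 'batch_norm' to nn.BatchNorm1d and 'layer_norm' to nn.LayerNorm (the real mapping lives elsewhere in wenet):

import torch.nn as nn

# Assumed shape of the registry; illustrative only.
WENET_NORM_CLASSES = {
    "batch_norm": nn.BatchNorm1d,
    "layer_norm": nn.LayerNorm,
}

channels, norm_eps = 256, 1e-6
norm = WENET_NORM_CLASSES["batch_norm"](channels, eps=norm_eps)
print(norm.eps)  # 1e-06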

wenet/transformer/decoder.py

Lines changed: 14 additions & 3 deletions

@@ -83,6 +83,7 @@ def __init__(
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -98,7 +99,7 @@ def __init__(
         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
-                                                              eps=1e-5)
+                                                              eps=norm_eps)
         self.use_output_layer = use_output_layer
         if use_output_layer:
             self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
@@ -122,6 +123,8 @@ def __init__(
                           activation, mlp_bias),
                 dropout_rate,
                 normalize_before,
+                layer_norm_type,
+                norm_eps,
             ) for _ in range(self.num_blocks)
         ])

@@ -329,6 +332,8 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):

         super().__init__()
@@ -352,7 +357,10 @@ def __init__(
             value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            layer_norm_type=layer_norm_type,
+            norm_eps=norm_eps,
+        )

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -373,7 +381,10 @@ def __init__(
             mlp_bias=mlp_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            layer_norm_type=layer_norm_type,
+            norm_eps=norm_eps,
+        )

     def forward(
         self,
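
A quick way to check that the new argument actually reaches every block is to build a small decoder and inspect its norms. A sketch, assuming the per-block layers are exposed as dec.decoders and that each norm keeps its eps attribute (as torch's LayerNorm does); sizes are illustrative:

from wenet.transformer.decoder import TransformerDecoder

# Illustrative sizes; all other arguments keep their defaults.
dec = TransformerDecoder(vocab_size=100, encoder_output_size=64,
                         num_blocks=2, norm_eps=1e-6)
assert dec.after_norm.eps == 1e-6
assert all(layer.norm1.eps == 1e-6 for layer in dec.decoders)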

wenet/transformer/decoder_layer.py

Lines changed: 4 additions & 3 deletions

@@ -49,6 +49,7 @@ def __init__(
         dropout_rate: float,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an DecoderLayer object."""
         super().__init__()
@@ -57,9 +58,9 @@ def __init__(
         self.src_attn = src_attn
         self.feed_forward = feed_forward
         assert layer_norm_type in ['layer_norm', 'rms_norm']
-        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
         self.dropout = nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before


wenet/transformer/encoder.py

Lines changed: 23 additions & 15 deletions

@@ -57,6 +57,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """
         Args:
@@ -107,7 +108,7 @@ def __init__(
         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
-                                                              eps=1e-5)
+                                                              eps=norm_eps)
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -373,6 +374,7 @@ def __init__(
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """ Construct TransformerEncoder

@@ -384,22 +386,24 @@
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa, layer_norm_type)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(output_size,
-                                    WENET_ATTENTION_CLASSES["selfattn"](
-                                        attention_heads, output_size,
-                                        attention_dropout_rate, query_bias,
-                                        key_bias, value_bias, use_sdpa),
-                                    mlp_class(output_size, linear_units,
-                                              dropout_rate, activation,
-                                              mlp_bias),
-                                    dropout_rate,
-                                    normalize_before,
-                                    layer_norm_type=layer_norm_type)
-            for _ in range(num_blocks)
+            TransformerEncoderLayer(
+                output_size,
+                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
+                                                    output_size,
+                                                    attention_dropout_rate,
+                                                    query_bias, key_bias,
+                                                    value_bias, use_sdpa),
+                mlp_class(output_size, linear_units, dropout_rate, activation,
+                          mlp_bias),
+                dropout_rate,
+                normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
+            ) for _ in range(num_blocks)
         ])


@@ -439,6 +443,8 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct ConformerEncoder

@@ -463,7 +469,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
@@ -500,5 +506,7 @@ def __init__(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
                 normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
            ) for _ in range(num_blocks)
        ])
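
In wenet, encoder arguments are typically populated from the encoder_conf section of a training YAML. A hedged sketch of how a config could now select the norm type and epsilon, assuming ConformerEncoder accepts the keyword arguments shown (values are illustrative):

from wenet.transformer.encoder import ConformerEncoder

encoder_conf = {
    'output_size': 256,
    'attention_heads': 4,
    'linear_units': 2048,
    'num_blocks': 12,
    'layer_norm_type': 'rms_norm',  # now threaded through ConformerEncoder
    'norm_eps': 1e-6,               # previously hard-coded to 1e-5
}
encoder = ConformerEncoder(input_size=80, **encoder_conf)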

wenet/transformer/encoder_layer.py

Lines changed: 9 additions & 7 deletions

@@ -47,14 +47,15 @@ def __init__(
         dropout_rate: float,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
         assert layer_norm_type in ['layer_norm', 'rms_norm']
-        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -140,6 +141,7 @@ def __init__(
         dropout_rate: float = 0.1,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
@@ -149,20 +151,20 @@ def __init__(
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
         self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
-            size, eps=1e-5)  # for the FNN module
+            size, eps=norm_eps)  # for the FNN module
         self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
-            size, eps=1e-5)  # for the MHA module
+            size, eps=norm_eps)  # for the MHA module
         if feed_forward_macaron is not None:
             self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
-                size, eps=1e-5)
+                size, eps=norm_eps)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
             self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
-                size, eps=1e-5)  # for the CNN module
+                size, eps=norm_eps)  # for the CNN module
         self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
-            size, eps=1e-5)  # for the final output of the block
+            size, eps=norm_eps)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
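
The epsilon that this commit makes configurable is the small constant added inside the normalization denominator to keep it away from zero; models trained in reduced precision often prefer a value other than the 1e-5 default. A minimal sketch of an RMSNorm-style layer showing where norm_eps ends up (an illustration, not wenet's actual implementation):

import torch

class SimpleRMSNorm(torch.nn.Module):
    """Normalize by the root-mean-square of the features, guarded by eps."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eps keeps the denominator finite when the features are near zero.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)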
