@@ -25,6 +25,7 @@
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_MLP_CLASSES,
+    WENET_NORM_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -55,6 +56,7 @@ def __init__(
         use_dynamic_left_chunk: bool = False,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
     ):
         """
         Args:
@@ -102,8 +104,10 @@ def __init__(
                           positional_dropout_rate),
         )
 
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
-        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+        self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
+                                                              eps=1e-5)
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
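For reference, the WENET_NORM_CLASSES lookup above is just a name-to-class registry. A minimal sketch of what wenet/utils/class_utils.py presumably exposes is shown below; the RMSNorm class here is a hypothetical stand-in for illustration, not WeNet's actual module.

```python
import torch


class RMSNorm(torch.nn.Module):
    """Hypothetical RMSNorm stand-in: rescale by the root-mean-square of the features."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight


# Assumed shape of the registry consumed by the encoder code above.
WENET_NORM_CLASSES = {
    'layer_norm': torch.nn.LayerNorm,
    'rms_norm': RMSNorm,
}

# The encoder then resolves the norm class purely by name:
after_norm = WENET_NORM_CLASSES['rms_norm'](256, eps=1e-5)
```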
@@ -368,6 +372,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
     ):
         """ Construct TransformerEncoder
 
@@ -379,19 +384,21 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(
-                output_size,
-                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-                                                    output_size,
-                                                    attention_dropout_rate,
-                                                    query_bias, key_bias,
-                                                    value_bias, use_sdpa),
-                mlp_class(output_size, linear_units, dropout_rate, activation,
-                          mlp_bias), dropout_rate, normalize_before)
+            TransformerEncoderLayer(output_size,
+                                    WENET_ATTENTION_CLASSES["selfattn"](
+                                        attention_heads, output_size,
+                                        attention_dropout_rate, query_bias,
+                                        key_bias, value_bias, use_sdpa),
+                                    mlp_class(output_size, linear_units,
+                                              dropout_rate, activation,
+                                              mlp_bias),
+                                    dropout_rate,
+                                    normalize_before,
+                                    layer_norm_type=layer_norm_type)
             for _ in range(num_blocks)
         ])
 
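End to end, the new argument is simply threaded from TransformerEncoder down to each encoder layer and the final after_norm. A usage sketch, assuming the common 80-dim fbank / 256-dim encoder setup found in the WeNet recipes:

```python
import torch

from wenet.transformer.encoder import TransformerEncoder

# Sketch only: build a Transformer encoder whose per-layer norms and final
# after_norm use RMSNorm instead of the default LayerNorm.
encoder = TransformerEncoder(
    input_size=80,               # assumed 80-dim fbank features
    output_size=256,
    attention_heads=4,
    num_blocks=6,
    layer_norm_type='rms_norm',  # new argument introduced by this diff
)

xs = torch.randn(2, 100, 80)       # (batch, time, feature)
xs_lens = torch.tensor([100, 80])  # valid frame counts per utterance
out, out_mask = encoder(xs, xs_lens)
# out has shape (batch, subsampled_time, 256); the time axis is reduced
# roughly 4x by the default conv2d subsampling front end.
```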