@@ -57,6 +57,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """
         Args:
@@ -107,7 +108,7 @@ def __init__(
         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
-                                                              eps=1e-5)
+                                                              eps=norm_eps)
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
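Note on the hunk above: `after_norm` is built by looking up a normalization class in a registry keyed by `layer_norm_type`, and the previously hard-coded `eps=1e-5` now forwards the new `norm_eps` argument. A minimal sketch of what that lookup resolves to for the default `'layer_norm'` type, assuming the registry maps it to `torch.nn.LayerNorm` (the `'rms_norm'` entry would be a project-specific RMSNorm module):

import torch

# Assumed equivalent of WENET_NORM_CLASSES['layer_norm'](output_size, eps=norm_eps)
output_size, norm_eps = 256, 1e-5
after_norm = torch.nn.LayerNorm(output_size, eps=norm_eps)
print(after_norm.eps)  # 1e-05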
@@ -373,6 +374,7 @@ def __init__(
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """ Construct TransformerEncoder

@@ -384,22 +386,24 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa, layer_norm_type)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(output_size,
-                                    WENET_ATTENTION_CLASSES["selfattn"](
-                                        attention_heads, output_size,
-                                        attention_dropout_rate, query_bias,
-                                        key_bias, value_bias, use_sdpa),
-                                    mlp_class(output_size, linear_units,
-                                              dropout_rate, activation,
-                                              mlp_bias),
-                                    dropout_rate,
-                                    normalize_before,
-                                    layer_norm_type=layer_norm_type)
-            for _ in range(num_blocks)
+            TransformerEncoderLayer(
+                output_size,
+                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
+                                                    output_size,
+                                                    attention_dropout_rate,
+                                                    query_bias, key_bias,
+                                                    value_bias, use_sdpa),
+                mlp_class(output_size, linear_units, dropout_rate, activation,
+                          mlp_bias),
+                dropout_rate,
+                normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
+            ) for _ in range(num_blocks)
         ])


@@ -439,6 +443,8 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct ConformerEncoder

@@ -463,7 +469,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
@@ -500,5 +506,7 @@ def __init__(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
                 normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
             ) for _ in range(num_blocks)
         ])
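Taken together, the changes thread `layer_norm_type` and `norm_eps` from the encoder constructors down to every layer's normalization module instead of relying on the hard-coded `eps=1e-5`. A minimal usage sketch, assuming the class is importable from `wenet.transformer.encoder`, that `input_size` is the acoustic feature dimension, and that the remaining defaults and the forward interface match the signatures shown above:

import torch
from wenet.transformer.encoder import ConformerEncoder  # assumed module path

encoder = ConformerEncoder(
    input_size=80,     # e.g. 80-dim filterbank features (assumption)
    output_size=256,
    norm_eps=1e-6,     # previously fixed at 1e-5; 'rms_norm' is also accepted via layer_norm_type
)
feats = torch.randn(1, 100, 80)   # (batch, time, feature)
feats_lens = torch.tensor([100])
out, out_mask = encoder(feats, feats_lens)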