@@ -795,7 +795,7 @@ def get_input_embeddings(self):
         return self.word_embeddings

     def _prepare_attn_mask(
-        self, attention_mask: Tensor, input_shape: Tuple[int, int], past_key_values_length: int, num_heads: int, dtype
+        self, attention_mask: Tensor, input_shape: Tuple[int, int], past_key_values_length: int, num_heads: int
     ) -> Tensor:
         # create causal mask
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
@@ -819,8 +819,9 @@ def _prepare_attn_mask(

         mask_shape = expanded_attn_mask.shape
         expanded_attn_mask = expanded_attn_mask.expand([mask_shape[0], num_heads, mask_shape[2], mask_shape[3]])
-        zero = paddle.zeros(expanded_attn_mask.shape, dtype=dtype)
-        neg_inf = paddle.full(expanded_attn_mask.shape, paddle.finfo(dtype).min, dtype=dtype)
+        # Attention scores will be cast to float32 in the following calculation, so we set the attention_mask dtype to float32
+        zero = paddle.zeros(expanded_attn_mask.shape, dtype=paddle.float32)
+        neg_inf = paddle.full(expanded_attn_mask.shape, paddle.finfo(paddle.float32).min, dtype=paddle.float32)
         expanded_attn_mask = paddle.where(expanded_attn_mask, zero, neg_inf)
         batch_size, num_heads, sq_len, kv_len = expanded_attn_mask.shape
         return expanded_attn_mask.reshape([batch_size * num_heads, sq_len, kv_len])
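Not part of the commit: a minimal sketch of what the patched _prepare_attn_mask now builds, assuming only that Paddle is installed. The shapes and the causal pattern below are made up for illustration. The boolean mask is turned into an additive float32 mask (0 where attention is allowed, the most negative float32 value elsewhere), so once it is added to the attention scores the softmax drives masked positions to effectively zero probability, regardless of the dtype the hidden states were computed in.

import paddle

# Illustrative shapes only: 2 = batch_size * num_heads, 4 = sequence length
scores = paddle.rand([2, 4, 4], dtype="float32")
allowed = paddle.tril(paddle.ones([4, 4])).astype("bool")  # causal pattern
allowed = allowed.unsqueeze(0).expand([2, 4, 4])

# Same construction as the patched code: float32 zeros vs. float32 "minus infinity"
zero = paddle.zeros(allowed.shape, dtype=paddle.float32)
neg_inf = paddle.full(allowed.shape, paddle.finfo(paddle.float32).min, dtype=paddle.float32)
attn_mask = paddle.where(allowed, zero, neg_inf)

# Masked (future) positions receive ~0 probability after the softmax
probs = paddle.nn.functional.softmax(scores + attn_mask, axis=-1)
print(probs[0])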
@@ -929,7 +930,6 @@ def forward(
                 input_shape=(batch_size, seq_length),
                 past_key_values_length=past_key_values_length,
                 num_heads=block_size,
-                dtype=hidden_states.dtype,
             )
         else:
             alibi = alibi.reshape([batch_size * self.config.n_head, 1, seq_length_with_past])
@@ -938,7 +938,6 @@ def forward(
                 input_shape=(batch_size, seq_length),
                 past_key_values_length=past_key_values_length,
                 num_heads=self.config.n_head,
-                dtype=hidden_states.dtype,
             )

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -1088,7 +1087,7 @@ def __init__(self, config):
         self.lm_head = BloomLMHead(config, self.bloom.word_embeddings.weight)
         self.criterion = BloomPretrainingCriterion(
            tensor_parallel_degree=config.tensor_parallel_degree,
-            tensor_parallel_output=True,
+            tensor_parallel_output=config.tensor_parallel_output,
        )

     def get_output_embeddings(self):
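Also not part of the commit: the last hunk stops hard-coding tensor_parallel_output=True and reads the flag from the config instead. A minimal sketch of how that flag would now flow through, assuming a BloomConfig/BloomForCausalLM pair as in PaddleNLP and that the base config exposes tensor_parallel_output and tensor_parallel_degree; the constructor arguments below are illustrative assumptions, not taken from the diff.

from paddlenlp.transformers import BloomConfig, BloomForCausalLM

# Tiny illustrative config; n_head matches the field used in the diff, the rest are assumptions
config = BloomConfig(hidden_size=64, n_head=4, n_layer=2, vocab_size=1000)
config.tensor_parallel_output = False  # previously ignored: the criterion always assumed True
model = BloomForCausalLM(config)       # BloomPretrainingCriterion now follows the config flag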