@@ -134,7 +134,10 @@ class PredictorArgument:
 
     @property
     def total_max_length(self):
-        return 8192  # Maximum sequence length.
+        if self.device == "npu":
+            return self.src_length + self.max_length
+        else:
+            return 8192  # Maximum sequence length.
 
 
 @dataclass
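A note on the new `total_max_length` behavior: on NPU the cap is derived from the configured input and output budgets (`src_length + max_length`) instead of the fixed 8192. A minimal sketch of the property in isolation; the field defaults below are illustrative, not the real `PredictorArgument` defaults:

```python
from dataclasses import dataclass

@dataclass
class PredictorArgument:
    # Stand-in for the real dataclass; only the fields used by the property are shown.
    device: str = "gpu"
    src_length: int = 1024  # illustrative default
    max_length: int = 1024  # illustrative default

    @property
    def total_max_length(self):
        if self.device == "npu":
            return self.src_length + self.max_length
        else:
            return 8192  # Maximum sequence length.

print(PredictorArgument(device="npu").total_max_length)  # 2048
print(PredictorArgument(device="gpu").total_max_length)  # 8192
```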
@@ -861,6 +864,35 @@ def init_model_inputs(self, config: PredictorArgument):
             self.model_inputs["tgt_mask"] = (
                 alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min
             ).cast(self.dtype)
+        elif config.device == "npu" and self.model_config.get("alibi", False):
+            lower_one_tril = paddle.tril(
+                paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)
+            )
+            lower_one_tril = lower_one_tril[None, None, :, :]
+            src_mask = lower_one_tril.tile([config.batch_size, 1, 1, 1])
+            tgt_mask = paddle.full(
+                shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype
+            )
+            arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype)
+            alibi_slopes = llm_utils.get_alibi_slopes(self.num_attention_heads)
+            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
+            alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1])
+            alibi_decoder = alibi.tile(
+                [
+                    config.batch_size,
+                    1,
+                    1,
+                    1,
+                ]
+            )
+            # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated!
+            src_mask = (
+                alibi_encoder + (1 - src_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            tgt_mask = (
+                alibi_decoder + (1 - tgt_mask) * paddle.finfo(self.dtype).min
+            ).cast(self.dtype)
+            self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
 
     def _preprocess(self, input_text: list[str]):
         if self.tokenizer.chat_template is not None:
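The added NPU branch builds ALiBi-biased causal masks and flattens them into `model_inputs["rope_emb"]`. Below is a standalone sketch of the encoder-side mask construction, assuming the standard ALiBi slope schedule 2^(-8i/n), which is what `llm_utils.get_alibi_slopes` is expected to return for power-of-two head counts; batch size, head count, and length are illustrative:

```python
import paddle

def alibi_slopes(num_heads):
    # Standard ALiBi slope schedule for power-of-two head counts:
    # slope_i = 2^(-8 * i / num_heads), i = 1..num_heads.
    # Assumed to match llm_utils.get_alibi_slopes in this case.
    start = 2 ** (-8.0 / num_heads)
    return paddle.to_tensor([start ** (i + 1) for i in range(num_heads)], dtype="float32")

batch_size, num_heads, max_len = 1, 4, 8
dtype = "float32"

# Causal (lower-triangular) mask broadcast to [batch, 1, max_len, max_len].
causal = paddle.tril(paddle.ones((max_len, max_len), dtype=dtype))[None, None, :, :]
src_mask = causal.tile([batch_size, 1, 1, 1])

# Per-head linear bias over key positions: slope_h * j, shape [1, heads, 1, max_len].
positions = paddle.arange(max_len).astype(dtype)
alibi = alibi_slopes(num_heads)[None, :, None, None] * positions
alibi_encoder = alibi.tile([batch_size, 1, max_len, 1])

# Allowed positions keep the linear bias; masked-out positions get the dtype minimum,
# mirroring the src_mask construction in the diff above.
biased_src_mask = (alibi_encoder + (1 - src_mask) * paddle.finfo(paddle.float32).min).cast(dtype)
print(biased_src_mask.shape)  # [1, 4, 8, 8]
```

The diff builds the one-row decoder mask (`tgt_mask`) the same way and then concatenates both flattened masks into `self.model_inputs["rope_emb"]`, reusing that input slot to carry the ALiBi masks on this device.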