@@ -132,7 +132,10 @@ class PredictorArgument:
132132
133133 @property
134134 def total_max_length (self ):
135- return 8192 # Maximum sequence length.
135+ if self .device == "npu" :
136+ return self .src_length + self .max_length
137+ else :
138+ return 8192 # Maximum sequence length.
136139
137140
138141@dataclass
@@ -859,6 +862,35 @@ def init_model_inputs(self, config: PredictorArgument):
859862 self .model_inputs ["tgt_mask" ] = (
860863 alibi_decoder + (1 - self .model_inputs ["tgt_mask" ]) * paddle .finfo (self .dtype ).min
861864 ).cast (self .dtype )
865+ elif config .device == "npu" and self .model_config .get ("alibi" , False ):
866+ lower_one_tril = paddle .tril (
867+ paddle .ones (shape = (config .total_max_length , config .total_max_length ), dtype = self .dtype )
868+ )
869+ lower_one_tril = lower_one_tril [None , None , :, :]
870+ src_mask = lower_one_tril .tile ([config .batch_size , 1 , 1 , 1 ])
871+ tgt_mask = paddle .full (
872+ shape = [config .batch_size , 1 , 1 , config .total_max_length ], fill_value = 1 , dtype = self .dtype
873+ )
874+ arange_tensor_encoder = paddle .arange (config .total_max_length ).astype (self .dtype )
875+ alibi_slopes = llm_utils .get_alibi_slopes (self .num_attention_heads )
876+ alibi = alibi_slopes [None , :, None , None ] * arange_tensor_encoder
877+ alibi_encoder = alibi .tile ([config .batch_size , 1 , config .total_max_length , 1 ])
878+ alibi_decoder = alibi .tile (
879+ [
880+ config .batch_size ,
881+ 1 ,
882+ 1 ,
883+ 1 ,
884+ ]
885+ )
886+ # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated!
887+ src_mask = (
888+ alibi_encoder + (1 - src_mask ) * paddle .finfo (self .dtype ).min
889+ ).cast (self .dtype )
890+ tgt_mask = (
891+ alibi_decoder + (1 - tgt_mask ) * paddle .finfo (self .dtype ).min
892+ ).cast (self .dtype )
893+ self .model_inputs ["rope_emb" ] = paddle .concat ([src_mask .reshape ([- 1 ]), tgt_mask .reshape ([- 1 ])])
862894
863895 def _preprocess (self , input_text : list [str ]):
864896 if self .tokenizer .chat_template is not None :