@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
419419 return (
420420 cos .cast (x .dtype ) if cos .dtype != x .dtype else cos ,
421421 sin .cast (x .dtype ) if sin .dtype != x .dtype else sin ,
422- self .cos_sin_table .cast (x .dtype )
423- if self .cos_sin_table is not None and self .cos_sin_table .dtype != x .dtype
424- else self .cos_sin_table ,
425422 )
426423
def get_fused_cos_sin(self, x, seq_len=None):
    """Return the fused cos/sin table in the dtype of ``x``.

    Args:
        x: Object whose ``dtype`` the returned table should have.
        seq_len: Not used by this implementation.

    Returns:
        ``self.cos_sin_table`` cast to ``x.dtype`` when the dtypes differ,
        the table itself when they already match, or ``None`` when no table
        is set.
    """
    table = self.cos_sin_table
    # Only pay for a cast when a table exists and its dtype differs.
    needs_cast = table is not None and table.dtype != x.dtype
    return table.cast(x.dtype) if needs_cast else table
429+
427430
428431class LlamaLinearScalingRotaryEmbedding (LlamaRotaryEmbedding ):
429432 def __init__ (self , dim , max_position_embeddings = 2048 , base = 10000 , scaling_factor = 1.0 ):
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
def forward(self, x, seq_len=None):
    """Return the ``(cos, sin)`` rotary tables for the first ``seq_len`` positions.

    Args:
        x: Tensor whose dtype the returned tables are cast to. The original
            inline comment gives its layout as
            [bs, num_attention_heads, seq_len, head_size].
        seq_len: Number of positions requested. ``None`` keeps every
            position held in the cached tables.

    Returns:
        A ``(cos, sin)`` tuple, each in ``x.dtype``.
    """
    # Comparing None with an int raises TypeError, so a missing seq_len is
    # treated as "within the cached range" instead of crashing.
    if seq_len is not None and seq_len > self.max_position_embeddings:
        # Longer than the cached range: rebuild the tables for this length.
        # The third (fused) value is not needed here.
        scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len)
    else:
        scale_cos, scale_sin = self.cos_cached, self.sin_cached
    # Slicing with seq_len=None keeps the whole second axis.
    cos = scale_cos[:, :seq_len, :, ...]
    sin = scale_sin[:, :seq_len, :, ...]
    return (
        cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
        sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
    )
497497
def get_fused_cos_sin(self, x, seq_len=None):
    """Return the fused cos/sin table for ``seq_len`` in the dtype of ``x``.

    Args:
        x: Object whose ``dtype`` the returned table should have.
        seq_len: Number of positions requested. ``None`` selects the cached
            table.

    Returns:
        The fused table cast to ``x.dtype`` when the dtypes differ, the
        table itself when they match, or ``None`` when no table is
        available.
    """
    # Comparing None with an int raises TypeError, so a missing seq_len
    # falls back to the cached table instead of crashing.
    if seq_len is not None and seq_len > self.max_position_embeddings:
        # Longer than the cached range: rebuild and keep only the fused table.
        _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
    else:
        scale_cos_sin = self.cos_sin_table
    if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype:
        return scale_cos_sin.cast(x.dtype)
    return scale_cos_sin
507+
498508
499509def rotate_half (x ):
500510 """Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
943953 sin .cast (value_states .dtype ) if sin .dtype != value_states .dtype else sin ,
944954 )
945955 else :
946- cos , sin , _ = self .rotary_emb (value_states , seq_len = kv_seq_len )
956+ cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
947957
948958 query_states , key_states = apply_rotary_pos_emb (query_states , key_states , cos , sin , position_ids )
949959
0 commit comments