@@ -96,7 +96,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ) -> None:
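(Context for the change: Hugging Face configs store `rope_theta` as a float — Llama 3, for instance, ships `rope_theta: 500000.0` — so annotating `base` as `int` never matched the values actually passed in; the same widening is applied to every signature below.)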
@@ -113,7 +113,7 @@ def __init__(
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
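For reference, the method retyped here evaluates the standard RoPE frequency schedule; a minimal standalone sketch (illustrative names, not the exact vLLM code):

```python
import torch

def compute_inv_freq(base: float, rotary_dim: int) -> torch.Tensor:
    # Standard RoPE frequencies: inv_freq[i] = base^(-2i / rotary_dim)
    # for i = 0, 1, ..., rotary_dim/2 - 1. A float base (e.g. 500000.0)
    # flows through the power op without an implicit int -> float cast.
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
    return 1.0 / (base**exponents)
```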
@@ -404,7 +404,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factors: Union[list[float], float],
         dtype: torch.dtype,
@@ -464,7 +464,7 @@ def __init__(self,
                  head_size: int,
                  rotary_dim: int,
                  max_position_embeddings: int,
-                 base: int,
+                 base: float,
                  is_neox_style: bool,
                  scaling_factor: float,
                  dtype: torch.dtype,
@@ -474,7 +474,7 @@ def __init__(self,
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
         inv_freq = super()._compute_inv_freq(base)
 
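The override above rescales the base before deferring to the parent formula; roughly, as an NTK-style sketch that ignores the `mixed_b` branch:

```python
import torch

# Sketch of the idea (names illustrative): stretching the base stretches
# every rotary wavelength, which extends the usable context window
# without retraining. scaling_factor plays the role of self.scaling_factor.
def ntk_scaled_inv_freq(base: float, rotary_dim: int,
                        scaling_factor: float) -> torch.Tensor:
    scaled_base = base * scaling_factor
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
    return 1.0 / (scaled_base**exponents)
```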
@@ -501,7 +501,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -582,7 +582,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -644,7 +644,7 @@ def __init__(
         rotary_dim: int,
         max_position_embeddings: int,
         original_max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         short_factor: list[float],
@@ -769,7 +769,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -877,7 +877,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         scaling_factor: float,
@@ -892,7 +892,7 @@ def __init__(
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         low_freq_wavelen = self.orig_max_position / self.low_freq_factor
         high_freq_wavelen = self.orig_max_position / self.high_freq_factor
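The surrounding method implements Llama 3's wavelength-banded frequency adjustment; a self-contained sketch of that logic (illustrative names, and omitting the `low_freq_factor == high_freq_factor` edge case):

```python
import math
import torch

# Wavelengths shorter than high_freq_wavelen are kept as-is, wavelengths
# longer than low_freq_wavelen are divided by the scaling factor, and the
# band in between is linearly interpolated between the two regimes.
def llama3_adjust(inv_freqs: torch.Tensor, orig_max_position: int,
                  low_freq_factor: float, high_freq_factor: float,
                  scaling_factor: float) -> torch.Tensor:
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    wave_len = 2 * math.pi / inv_freqs
    smooth = (orig_max_position / wave_len - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    return torch.where(
        wave_len < high_freq_wavelen,
        inv_freqs,
        torch.where(wave_len > low_freq_wavelen,
                    inv_freqs / scaling_factor,
                    (1 - smooth) * inv_freqs / scaling_factor
                    + smooth * inv_freqs),
    )
```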
@@ -923,14 +923,14 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ):
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
         return inv_freqs
@@ -989,7 +989,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
@@ -1529,7 +1529,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         chunk_size: int,
@@ -1558,7 +1558,7 @@ def __init__(
                              q_inter_cache,
                              persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
         # However, we use `torch.arange(..., dtype=torch.float)` instead to
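A tiny demonstration of the distinction the NOTE draws: both calls yield the same float32 values here, but the first builds an int64 tensor and casts, while the second computes in float32 from the start:

```python
import torch

a = torch.arange(8).float()              # int64 arange, then cast to float32
b = torch.arange(8, dtype=torch.float)   # float32 arange from the start
assert torch.equal(a, b)                 # identical values for small ranges
```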
@@ -1705,7 +1705,7 @@ def get_rope(
     head_size: int,
     rotary_dim: int,
     max_position: int,
-    base: int,
+    base: float,
     is_neox_style: bool = True,
     rope_scaling: Optional[dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
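A hypothetical call site showing the effect of the widened annotation — an HF config's float `rope_theta` can now be forwarded to `get_rope` without a cast (argument values and the import path are illustrative):

```python
import torch

# Assumed module path for the factory defined in this file.
from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=8192,
    base=500000.0,        # e.g. Llama 3's rope_theta, a float
    is_neox_style=True,
    rope_scaling=None,
    dtype=torch.bfloat16,
)
```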