Commit 65271eb

DarkLight1337 authored and amitm02 committed
[Misc] Update type annotation for rotary embedding base (vllm-project#18914)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: amit <[email protected]>
1 parent 47a5b2f commit 65271eb

File tree

4 files changed, +23 −26 lines changed


benchmarks/kernels/benchmark_rope.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora(
     seed: int,
     device: str,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

tests/kernels/core/test_pos_encoding.py

Lines changed: 3 additions & 3 deletions
@@ -70,7 +70,7 @@ def test_rotary_embedding(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     if rotary_dim is None:
         rotary_dim = head_size
@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 17 additions & 17 deletions
@@ -96,7 +96,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ) -> None:
@@ -113,7 +113,7 @@ def __init__(
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
@@ -404,7 +404,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factors: Union[list[float], float],
         dtype: torch.dtype,
@@ -464,7 +464,7 @@ def __init__(self,
                 head_size: int,
                 rotary_dim: int,
                 max_position_embeddings: int,
-                base: int,
+                base: float,
                 is_neox_style: bool,
                 scaling_factor: float,
                 dtype: torch.dtype,
@@ -474,7 +474,7 @@ def __init__(self,
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
         inv_freq = super()._compute_inv_freq(base)
 
@@ -501,7 +501,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -582,7 +582,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -644,7 +644,7 @@ def __init__(
         rotary_dim: int,
         max_position_embeddings: int,
         original_max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         short_factor: list[float],
@@ -769,7 +769,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -877,7 +877,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         scaling_factor: float,
@@ -892,7 +892,7 @@ def __init__(
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         low_freq_wavelen = self.orig_max_position / self.low_freq_factor
         high_freq_wavelen = self.orig_max_position / self.high_freq_factor
@@ -923,14 +923,14 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ):
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
         return inv_freqs
@@ -989,7 +989,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
@@ -1529,7 +1529,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         chunk_size: int,
@@ -1558,7 +1558,7 @@ def __init__(
                              q_inter_cache,
                              persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
         # However, we use `torch.arange(..., dtype=torch.float)` instead to
@@ -1705,7 +1705,7 @@ def get_rope(
     head_size: int,
     rotary_dim: int,
     max_position: int,
-    base: int,
+    base: float,
     is_neox_style: bool = True,
     rope_scaling: Optional[dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,

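The widened annotation on `get_rope` matters in practice because many model configs specify `rope_theta` as a float. A minimal sketch of the call site this change accommodates, assuming the `get_rope` signature shown in the last hunk above and the module path implied by this file's location; the values are illustrative, not taken from the repo:

from vllm.model_executor.layers.rotary_embedding import get_rope

# rope_theta is often a float in model configs (e.g. 500000.0); with the old
# `base: int` annotation, passing such a value would be flagged by type checkers.
rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=8192,
    base=500000.0,
    is_neox_style=True,
)
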
vllm/model_executor/models/minimax_text_01.py

Lines changed: 2 additions & 5 deletions
@@ -141,7 +141,7 @@ def __init__(
         head_size: int,
         rotary_dim: int,
         max_position: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         cache_dtype: torch.dtype,
     ) -> None:
@@ -155,10 +155,7 @@ def __init__(
         cache = self._compute_cos_sin_cache().to(cache_dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-    def _compute_inv_freq(
-        self,
-        base: Union[int, float],
-    ) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         inv_freq = 1.0 / (base**(torch.arange(
             0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))

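For reference, `float` is the natural type here because `base` only ever feeds a floating-point exponentiation, as in the `_compute_inv_freq` body above. A self-contained sketch of that computation (a standalone function with illustrative values, not code from the repo):

import torch

def compute_inv_freq(base: float, rotary_dim: int) -> torch.Tensor:
    # Mirrors the expression above: inv_freq[i] = 1 / base**(2*i / rotary_dim)
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
    return 1.0 / (base ** exponents)

# Integer-valued and fractional bases behave the same:
print(compute_inv_freq(10000.0, 8))    # approximately [1.0, 0.1, 0.01, 0.001]
print(compute_inv_freq(500000.0, 8))
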
0 commit comments
