Skip to content

Avoid multiple instantiations of the RoPE class #1828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
head_size: int,
rotary_dim: int,
Expand All @@ -280,6 +283,10 @@ def get_rope(
is_neox_style: bool,
rope_scaling: Optional[Dict[str, Any]],
) -> RotaryEmbedding:
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
Expand Down Expand Up @@ -312,4 +319,5 @@ def get_rope(
**extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb