Commit 5aac371

NickLucche authored and lulmer committed
[V1][TPU] Support V1 Sampler for ragged attention (vllm-project#14227)
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent de46047 commit 5aac371

File tree

6 files changed: +535 -55 lines changed


tests/v1/tpu/test_sampler.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time

import pytest

from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time, enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]
        # First inference should be slow
        sampling_params = SamplingParams(
            temperature=0.7,
            # top_p=0.6, # TODO too slow!
            # top_k=10,
            min_p=0.2,
            max_tokens=16)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run1 = time() - s

        # Second request with different params, but which we already
        # compiled for in the previous eager iteration.
        sampling_params = SamplingParams(temperature=0.1,
                                         min_p=0.8,
                                         max_tokens=24)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run2 = time() - s
        # Much faster after compiling
        assert run1 * 0.1 > run2
        print("TIMES", run1, run2)

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run3 = time() - s
        assert run1 * 0.1 > run3
        print("TIMES", run1, run3)


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_different(model_name: str):
    """
    Test significantly different sampling params to assert the model produces
    different results.
    """
    llm = LLM(
        model_name,
        enforce_eager=True,
        max_num_seqs=1,
        max_model_len=64,
        # TODO: setting to 0.5 or it will go OOM
        gpu_memory_utilization=0.5)
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ]
    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
    output = llm.generate(prompts, sampling_params)

    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
    output2 = llm.generate(prompts, sampling_params)
    assert output[0].outputs[0].text != output2[0].outputs[0].text
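
Side note on the pattern above: since XLA metrics cannot be read from the engine process, the test falls back to wall-clock timing. The sketch below factors that heuristic into a standalone helper; it is illustrative only and not part of the commit, and `cold_fn`/`warm_fn` are hypothetical stand-ins for `llm.generate` calls with fixed prompts and different `SamplingParams`.

# Illustrative only, not part of this commit: the timing heuristic from
# test_sampler_compilation, factored into a reusable check.
from time import time
from typing import Callable


def assert_no_recompilation(cold_fn: Callable[[], None],
                            warm_fn: Callable[[], None],
                            slack: float = 0.1) -> None:
    start = time()
    cold_fn()  # first run: traces and compiles the XLA graphs
    cold = time() - start

    start = time()
    warm_fn()  # later run: should reuse the already-compiled graphs
    warm = time() - start

    # Mirrors the test's `run1 * 0.1 > run2` assertion: a warm run that is
    # not roughly 10x faster suggests the parameter change forced a recompile.
    assert cold * slack > warm, (
        f"warm run {warm:.2f}s vs cold run {cold:.2f}s: possible recompilation")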

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 15 additions & 1 deletion
@@ -65,6 +65,8 @@ def __init__(self):
                 "native implementation of top-p & top-k sampling. For the "
                 "best performance, please install FlashInfer.")
             self.forward = self.forward_native
+        elif current_platform.is_tpu():
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -96,6 +98,18 @@ def forward_cuda(
             return random_sample(probs, generators)
         return flashinfer_sample(probs, k, p, generators)
 
+    def forward_tpu(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # TODO Placeholder for TPU optimized topk/p kernel
+        # logits = apply_top_k_top_p(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
@@ -112,7 +126,7 @@ def apply_top_k_top_p(
 
     if k is not None:
         # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
         # Get all the top_k values.
         top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
         top_k_mask = logits_sort < top_k_mask
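
For context on the placeholder in `forward_tpu`: the masking it currently skips is, roughly, a sort-then-mask pass over the logits. The snippet below is a minimal, self-contained sketch of that idea in plain PyTorch; it is not vLLM's `apply_top_k_top_p` itself, and the helper name and exact masking rules are assumptions made for illustration.

# Illustrative sketch only (plain PyTorch, not vLLM's implementation).
import torch


def sketch_top_k_top_p(logits: torch.Tensor, k: torch.Tensor,
                       p: torch.Tensor) -> torch.Tensor:
    # Sort ascending so the last k entries of each row are the top-k.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Top-k: everything below the k-th largest value becomes -inf.
    top_k_index = logits_sort.size(1) - k.to(torch.long)  # shape: B
    top_k_value = logits_sort.gather(1, top_k_index.unsqueeze(dim=1))
    logits_sort.masked_fill_(logits_sort < top_k_value, -float("inf"))

    # Top-p: drop the low-probability prefix whose cumulative mass stays
    # below 1 - p, always keeping the most likely token.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_cum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_cum <= (1.0 - p.unsqueeze(dim=1))
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Undo the sort so masked logits line up with the original vocab order.
    return torch.empty_like(logits_sort).scatter_(1, logits_idx, logits_sort)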

vllm/v1/sample/tpu/__init__.py

Whitespace-only changes.

vllm/v1/sample/tpu/metadata.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch_xla.core.xla_model as xm

from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int,
                     torch.Generator] = field(default_factory=lambda: dict())

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        temp = self.temperature
        if self.indices_do_sample is None:
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        E.g. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
             do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(num_samples, device,
                                                       indices_do_sample=\
                                                       padded_do_sample_indices,
                                                       do_argmax=do_argmax
                                                       )
        supported_params = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            new_val = getattr(new_metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                new_val[:num_do_sample] = old_val
            setattr(new_metadata, p_name, new_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        init_kwargs = dict()
        for p_name, (default_val,
                     dtype) in sampling_metadata_disable_value.items():
            default_tensor = torch.full((num_samples, ),
                                        default_val,
                                        dtype=dtype,
                                        device=device)
            init_kwargs[p_name] = default_tensor

        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
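
To make the fixed-shape padding described in the `from_sampling_metadata` docstring above concrete: live per-request values overwrite the prefix of a tensor pre-filled with that parameter's "no-op" default (temperature=1.0, min_p=0.0 per `_get_default_params_values`), so the traced XLA graph always sees the same shapes. The `pad_param` helper below is a minimal hypothetical sketch and not part of this commit.

# Minimal sketch of the fixed-shape padding idea; illustrative only.
import torch


def pad_param(values: list[float], padded_size: int, default: float,
              dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Start from a tensor filled with the parameter's "no-op" default,
    # then overwrite the live prefix with the actual per-request values.
    out = torch.full((padded_size, ), default, dtype=dtype)
    out[:len(values)] = torch.tensor(values, dtype=dtype)
    return out


# Two live requests padded to a pre-compiled batch size of 4:
temperature = pad_param([0.7, 0.2], padded_size=4, default=1.0)
min_p = pad_param([0.2, 0.8], padded_size=4, default=0.0)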
