
[Core] Use flashinfer sampling kernel when available #7137


Merged · 27 commits · Aug 19, 2024
Changes from 13 commits
Commits (27)
7e47c2d
Use flashinfer kernel to do sampling if available
peng1999 Aug 2, 2024
fa61264
Merge remote-tracking branch 'upstream/main' into opt-topk
peng1999 Aug 2, 2024
56beab0
Fix type mismatch
peng1999 Aug 5, 2024
5396c9d
Some renaming
peng1999 Aug 5, 2024
5999bd3
Fallback for flashinfer sampler
peng1999 Aug 5, 2024
420b004
Formatting fix
peng1999 Aug 5, 2024
98d372e
Tests fix
peng1999 Aug 5, 2024
0a8be18
Fix mypy
peng1999 Aug 5, 2024
f170646
Add test for flashinfer sampler
peng1999 Aug 5, 2024
88c8a98
Suppress yapf on import
peng1999 Aug 5, 2024
c404cd5
Fix pipeline
peng1999 Aug 5, 2024
c361a95
Change back to torch generator, add env flags
peng1999 Aug 6, 2024
8af0e09
Merge remote-tracking branch 'upstream/main' into opt-topk
peng1999 Aug 6, 2024
99f7ecc
rename env for flashinfer, rollback changes in utils
peng1999 Aug 7, 2024
7e03711
rollback changes to utils
peng1999 Aug 7, 2024
6416046
rename env
peng1999 Aug 8, 2024
fdc23a3
add top_k_top_p when fallback
peng1999 Aug 8, 2024
b97c911
Adapt flashinfer 0.1.4
peng1999 Aug 12, 2024
f8d7093
Revert changes to sampling_metadata
peng1999 Aug 12, 2024
2d7e5c3
Change flashinfer 0.1.2 to 0.1.4 in test
peng1999 Aug 12, 2024
20eee6a
Merge remote-tracking branch 'upstream/main' into opt-topk
peng1999 Aug 12, 2024
f893110
Disable flashinfer in GPTQ reproduce test
peng1999 Aug 15, 2024
e4cfcfc
Disable flashinfer sampler in distributed test
peng1999 Aug 15, 2024
c5194ec
Merge remote-tracking branch 'upstream/main' into opt-topk
peng1999 Aug 15, 2024
0ec8b61
Disable flashinfer sampler by default
peng1999 Aug 16, 2024
9eaea5c
Update vllm/envs.py
peng1999 Aug 17, 2024
18d59a1
Merge branch 'vllm-project:main' into opt-topk
peng1999 Aug 19, 2024
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -179,7 +179,9 @@ steps:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- tests/samplers
command: pytest -v -s samplers
commands:
- pytest -v -s samplers
- VLLM_NO_FLASHINFER_SAMPLER=1 pytest -v -s samplers

- label: LogitsProcessor Test # 5min
mirror_hardwares: [amd]
5 changes: 4 additions & 1 deletion tests/samplers/test_sampler.py
@@ -625,7 +625,10 @@ def mock_sample(probs, *args, **kwargs):
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
# top-k and top-p are only calculated when the flashinfer kernel is not available
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

assert sample_probs is not None
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -29,6 +29,7 @@
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_NO_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -250,6 +251,10 @@ def get_default_config_root():
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

# If set, vllm will not use flashinfer sampler
"VLLM_NO_FLASHINFER_SAMPLER":
lambda: bool(os.getenv("VLLM_NO_FLASHINFER_SAMPLER", 0)),

# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
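
A side note on how this flag is parsed (an editor's sketch, not part of the diff; the helper name below is hypothetical): `os.getenv` returns a string whenever the variable is set, so any non-empty value, including "0", makes the expression truthy and disables the FlashInfer sampler; only an unset or empty variable leaves it enabled.

```python
import os

# Mirrors the lookup registered above (sketch only, not part of the PR):
# bool("0") is True in Python, so setting the variable to "0" still
# counts as "disable the FlashInfer sampler".
def flashinfer_sampler_disabled() -> bool:
    return bool(os.getenv("VLLM_NO_FLASHINFER_SAMPLER", 0))

if __name__ == "__main__":
    print(flashinfer_sampler_disabled())  # False when the variable is unset
```
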
109 changes: 83 additions & 26 deletions vllm/model_executor/layers/sampler.py
@@ -1,5 +1,7 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from importlib.util import find_spec
from math import inf
from typing import Dict, List, Optional, Tuple

@@ -11,6 +13,7 @@
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton

import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
@@ -19,6 +22,15 @@
PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)

if not envs.VLLM_NO_FLASHINFER_SAMPLER and find_spec("flashinfer"):
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

@@ -121,7 +133,7 @@ def forward(
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

@@ -473,34 +485,65 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[SequenceGroupToSample]] = None,
is_fallback: bool = False,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
if num_samples > 1 and not is_fallback:
probs = probs.repeat_interleave(num_samples, dim=0)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
q[sample_idx:sample_idx +
stride].exponential_(generator=seq_group.generator)
sample_idx += stride
return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.random_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].random_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
probs,
uniform_samples,
top_ks,
top_ps,
)
if not success.all():
warnings.warn("Sampling with FlashInfer failed, fallback.",
stacklevel=2)
return _multinomial(probs, num_samples, seq_groups, is_fallback=True)
return batch_next_token_ids.view(-1, num_samples)


def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@@ -563,18 +606,28 @@ def _sample_with_torch(
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
}

multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)

if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_best_of_in_batch,
seq_groups=seq_groups_arg)

if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)

elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
@@ -692,9 +745,12 @@ def _sample_with_triton_kernel(


def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
@@ -712,6 +768,7 @@ def _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
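
For readers unfamiliar with the fallback path in `_multinomial` above: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability p_i (the exponential-race form of the Gumbel-max trick), which is what lets the sampler avoid `torch.multinomial` and its CPU sync. A minimal standalone check, assuming only PyTorch on CPU (editor's sketch, not part of the diff):

```python
import torch

# Standalone sanity check of the trick used in _multinomial above
# (editor's sketch, not part of the PR): argmax(p / E) with
# E ~ Exponential(1) draws index i with probability p_i.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7]).expand(200_000, 3).contiguous()
q = torch.empty_like(probs).exponential_()
samples = probs.div(q).argmax(dim=1)
freqs = torch.bincount(samples, minlength=3).float() / samples.numel()
print(freqs)  # expected to be close to tensor([0.1, 0.2, 0.7])
```
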
11 changes: 5 additions & 6 deletions vllm/model_executor/sampling_metadata.py
@@ -220,13 +220,12 @@ def _prepare_seq_groups(
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
seed = sampling_params.seed

if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
generator = torch.Generator(device=device).manual_seed(
sampling_params.seed)
if generators is not None:
generators[seq_group_metadata.request_id] = generator
if seed is not None and generators is not None:
generator = torch.Generator(device=device).manual_seed(seed)
generators[seq_group_metadata.request_id] = generator

num_prompts += 1
num_prefill_sample = len(seq_ids)
@@ -243,7 +242,7 @@
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0

if sampling_params.seed is not None and generators is not None:
if seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)

# Update indices to select from the model output.
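
The hunk above hoists `sampling_params.seed` into a local and only creates a per-request generator when both a seed and a `generators` dict are provided. As a quick illustration of why a dedicated, seeded `torch.Generator` is used (editor's sketch on CPU, not part of the diff; `draw` is a hypothetical helper):

```python
import torch

# Editor's sketch (not part of the PR): a per-request, seeded
# torch.Generator reproduces the same random draws on every run,
# independently of any other sampling happening in the process.
def draw(seed: int) -> torch.Tensor:
    gen = torch.Generator(device="cpu").manual_seed(seed)
    return torch.empty(4).exponential_(generator=gen)

assert torch.equal(draw(42), draw(42))
```
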
17 changes: 12 additions & 5 deletions vllm/utils.py
@@ -666,8 +666,8 @@ def make_tensor_with_pad(
pad: T,
dtype: torch.dtype,
*,
device: Union[str, torch.device],
max_len: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
@@ -679,11 +679,18 @@
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
return async_numpy_to_tensor(padded_x, device)

return tensor

def async_numpy_to_tensor(x: npt.NDArray, device: Union[str, torch.device]):
"""
Make a tensor from a numpy array asynchronously. Use pinned memory
if possible.
"""
t = torch.from_numpy(x)
if is_pin_memory_available():
t = t.pin_memory()
return t.to(device, non_blocking=True)


def async_tensor_h2d(
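
The new `async_numpy_to_tensor` helper follows the standard pinned-memory pattern: a page-locked host tensor can be copied to the GPU with `non_blocking=True`, so the host-to-device transfer can overlap other work. A minimal sketch of the same pattern (editor's sketch, not the PR's API; `numpy_to_device` is a hypothetical name, and `torch.cuda.is_available()` stands in for vLLM's own pin-memory check):

```python
import numpy as np
import torch

# Editor's sketch of the pinned-memory pattern behind async_numpy_to_tensor
# (not part of the diff). On CPU-only machines the non_blocking flag is a no-op.
def numpy_to_device(x: np.ndarray, device: str) -> torch.Tensor:
    t = torch.from_numpy(x)
    if torch.cuda.is_available():
        # Page-locked host memory allows an asynchronous H2D copy.
        t = t.pin_memory()
    return t.to(device, non_blocking=True)

padded = np.zeros((4, 16), dtype=np.int64)
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = numpy_to_device(padded, device)
```
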