Change top_k to be disabled with 0 (still accept -1 for now) #17773

Merged (3 commits) on May 9, 2025
2 changes: 1 addition & 1 deletion tests/samplers/test_sampler.py

@@ -478,7 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
         sampling_params = SamplingParams(
             temperature=random.random() + 0.1,
             top_p=min(random.random() + 0.1, 1),
-            top_k=random.randint(0, 10) or -1,
+            top_k=random.randint(0, 10),
             n=n,
             presence_penalty=random.randint(0, 1),
         )
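The removed `or -1` was only there because 0 used to be an invalid value: `random.randint(0, 10) or -1` exploits the falsiness of 0 to remap it to -1. With 0 now meaning "disabled", the raw draw can be used directly. A standalone sketch of the two idioms (not part of the test suite, just an illustration):

```python
import random

random.seed(0)

# Old idiom: 0 is falsy, so `or -1` turned a random 0 into -1 (the old "disabled" value).
old_top_k = random.randint(0, 10) or -1   # never yields 0

# New idiom: 0 is itself a valid "disabled" value, so no remapping is needed.
new_top_k = random.randint(0, 10)         # 0 now means "consider all tokens"
```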
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/protocol.py

@@ -409,7 +409,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }

@@ -853,7 +853,7 @@ class CompletionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }

@@ -1679,7 +1679,7 @@ class TranscriptionRequest(OpenAIBaseModel):
         "repetition_penalty": 1.0,
         "temperature": 1.0,
         "top_p": 1.0,
-        "top_k": -1,
+        "top_k": 0,
         "min_p": 0.0,
     }
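With these defaults, an OpenAI-compatible request that does not set top_k now resolves to 0 (top-k disabled) rather than -1. A minimal sketch of such a request, assuming a vLLM OpenAI-compatible server is running locally; the URL, API key, and model name are placeholders:

```python
from openai import OpenAI

# Assumes a local vLLM OpenAI-compatible server; base_url and model are placeholders.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    # top_k is omitted, so the server-side default of 0 (disabled) applies.
    # Passing extra_body={"top_k": 0} or extra_body={"top_k": -1} is equivalent for now.
)
print(response.choices[0].message.content)
```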
2 changes: 1 addition & 1 deletion vllm/model_executor/sampling_metadata.py

@@ -416,7 +416,7 @@ def from_sampling_metadata(

     # k should not be greater than the vocab size.
     top_k = min(sampling_params.top_k, vocab_size)
-    top_k = vocab_size if top_k == -1 else top_k
+    top_k = vocab_size if top_k < 1 else top_k
     if temperature < _SAMPLING_EPS:
         # NOTE: Zero temperature means deterministic sampling
         # (i.e., greedy sampling or beam search).
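Mapping any `top_k < 1` to `vocab_size` makes a disabled top-k a no-op, because keeping the top `vocab_size` logits keeps everything. A small PyTorch sketch of that behavior (an illustration of the idea, not the vLLM kernel):

```python
import torch

def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Mask all but the top_k logits; top_k <= 0 means top-k is disabled."""
    vocab_size = logits.shape[-1]
    # Mirror the vLLM logic: clamp to the vocab size, treat < 1 as "keep everything".
    top_k = min(top_k, vocab_size)
    top_k = vocab_size if top_k < 1 else top_k
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return torch.where(logits < kth_value,
                       torch.full_like(logits, float("-inf")), logits)

logits = torch.randn(1, 8)
assert torch.equal(apply_top_k(logits, 0), logits)    # 0 disables the filter
assert torch.equal(apply_top_k(logits, -1), logits)   # -1 still disables it
```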
13 changes: 7 additions & 6 deletions vllm/sampling_params.py

@@ -149,7 +149,7 @@ class SamplingParams(
     top_p: Float that controls the cumulative probability of the top tokens
         to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
     top_k: Integer that controls the number of top tokens to consider. Set
-        to -1 to consider all tokens.
+        to 0 (or -1) to consider all tokens.
     min_p: Float that represents the minimum probability for a token to be
         considered, relative to the probability of the most likely token.
         Must be in [0, 1]. Set to 0 to disable this.

@@ -209,7 +209,7 @@ class SamplingParams(
     repetition_penalty: float = 1.0
     temperature: float = 1.0
     top_p: float = 1.0
-    top_k: int = -1
+    top_k: int = 0
     min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None

@@ -256,7 +256,7 @@ def from_optional(
         repetition_penalty: Optional[float] = 1.0,
         temperature: Optional[float] = 1.0,
         top_p: Optional[float] = 1.0,
-        top_k: int = -1,
+        top_k: int = 0,
         min_p: float = 0.0,
         seed: Optional[int] = None,
         stop: Optional[Union[str, list[str]]] = None,

@@ -376,7 +376,7 @@ def __post_init__(self) -> None:
         if self.temperature < _SAMPLING_EPS:
             # Zero temperature means greedy sampling.
             self.top_p = 1.0
-            self.top_k = -1
+            self.top_k = 0
             self.min_p = 0.0
             self._verify_greedy_sampling()

@@ -404,8 +404,9 @@ def _verify_args(self) -> None:
                 f"temperature must be non-negative, got {self.temperature}.")
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
-        if self.top_k < -1 or self.top_k == 0:
-            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+        # quietly accept -1 as disabled, but prefer 0
+        if self.top_k < -1:
+            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
                              f"got {self.top_k}.")
         if not isinstance(self.top_k, int):
             raise TypeError(
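From the caller's perspective, 0 and -1 now both disable top-k, while values below -1 are rejected by `_verify_args`. A short usage sketch against the new validation (values chosen only for illustration):

```python
from vllm import SamplingParams

# Preferred way to disable top-k going forward.
params = SamplingParams(temperature=0.8, top_p=0.95, top_k=0)

# -1 is still quietly accepted as "disabled" for backward compatibility.
legacy_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=-1)

# Anything below -1 is rejected.
try:
    SamplingParams(top_k=-2)
except ValueError as exc:
    print(exc)  # top_k must be 0 (disable), or at least 1, got -2.
```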
2 changes: 1 addition & 1 deletion vllm/worker/neuron_model_runner.py

@@ -348,7 +348,7 @@ def _convert_to_neuron_sampling_params(
         if temperature == 0.0:
             # Enable greedy sampling on zero temperature
             return (1, 1.0, 1.0)
-        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+        if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
             top_k = self._MAX_NEURON_SAMPLING_TOP_K

         return (top_k, top_p, temperature)
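On the Neuron backend, a disabled or out-of-range top-k is clamped to the hardware maximum rather than passed through. A rough sketch of the clamping, with a placeholder maximum standing in for `self._MAX_NEURON_SAMPLING_TOP_K`:

```python
_MAX_NEURON_SAMPLING_TOP_K = 256  # placeholder value for illustration only

def clamp_neuron_top_k(top_k: int) -> int:
    # "Disabled" values (0 or -1) and oversized values both fall back to the maximum.
    if top_k < 1 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
        return _MAX_NEURON_SAMPLING_TOP_K
    return top_k

assert clamp_neuron_top_k(0) == _MAX_NEURON_SAMPLING_TOP_K
assert clamp_neuron_top_k(-1) == _MAX_NEURON_SAMPLING_TOP_K
assert clamp_neuron_top_k(40) == 40
```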
2 changes: 1 addition & 1 deletion vllm/worker/tpu_model_runner.py

@@ -525,7 +525,7 @@ def _prepare_sample(
                     "Top-p sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
             p.append(sampling_params.top_p)
-            if sampling_params.top_k != -1:
+            if sampling_params.top_k > 0:
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
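On the TPU backend the guard becomes a positive-value test, so both 0 and -1 pass through as "top-k disabled" while any enabled value still raises. A minimal sketch of the check as it reads after this change:

```python
def check_tpu_top_k(top_k: int) -> None:
    # Only an explicitly enabled top-k (> 0) is rejected; 0 and -1 mean "disabled".
    if top_k > 0:
        raise NotImplementedError(
            "Top-k sampling is currently disabled for the TPU backend "
            "due to performance issues.")

check_tpu_top_k(0)     # disabled: accepted
check_tpu_top_k(-1)    # legacy disabled: accepted
# check_tpu_top_k(5)   # would raise NotImplementedError
```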