Skip to content

[Spec Decoding] Use target model max length as default for draft model #7706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
Expand Down Expand Up @@ -210,7 +211,8 @@ def __init__(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
- sliding_window_len=self.get_hf_config_sliding_window())
+ sliding_window_len=self.get_hf_config_sliding_window(),
+ spec_target_max_model_len=spec_target_max_model_len)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
Expand Down Expand Up @@ -1134,6 +1136,7 @@ def maybe_create_spec_config(
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=target_model_config.max_model_len,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
Expand Down Expand Up @@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
spec_target_max_model_len: Optional[int] = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
Expand Down Expand Up @@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
# If max_model_len is specified, we use it.
return max_model_len

if spec_target_max_model_len is not None:
# If this is a speculative draft model, we use the max model len
# from the target model.
return spec_target_max_model_len

default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
Expand Down
Loading