[Speculative Decoding] EAGLE Implementation with Top-1 proposer #6830
Merged
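For context on the title: a "Top-1" proposer drafts a single greedy chain of k tokens per speculation step (no token tree), which the target model then verifies with rejection sampling. A minimal sketch of that loop, using hypothetical names (`draft_step`, `k`) that are not part of this PR:

import torch

def propose_top1(draft_step, input_ids: torch.Tensor,
                 prev_hidden: torch.Tensor, k: int) -> list:
    """Greedy (top-1) proposal: one chain of k draft tokens, no token tree.

    `draft_step` is a hypothetical callable that runs one forward pass of the
    EAGLE draft head and returns (logits, hidden_states).
    """
    proposals = []
    for _ in range(k):
        logits, prev_hidden = draft_step(input_ids, prev_hidden)
        next_token = torch.argmax(logits[:, -1], dim=-1)  # top-1 = greedy pick
        proposals.append(next_token)
        input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
    # The target model then scores the chain and accepts a prefix of it via
    # rejection sampling, so the output distribution is unchanged.
    return proposals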
Changes from 13 commits (54 commits total)

Commits
All commits are by abhigoyal1997:

3c590b1  initial changes to support EAGLE
5f5bed1  handling hidden_states in case of bonus tokens since EAGLE will need it
023e72d  enabling CUDA graph
8ac1570  adding E2E test and formatting
b379948  minor bug fix in graph capture
aef9c00  fixing broadcasting of hidden states in distributed worker
c8d63bd  formatting
733ca4f  Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
1a0aa60  formatting
83b3dd8  Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
b1f05ac  Masking position=0 in inputs for EAGLE
bdee07c  reformatting
441374f  Fixing the order of execution for scorer and proposer in non-driver w…
0d1cbae  Adding hidden state propagation to _execute_model_spmd
b60384a  Adding CUDA graph tests for medusa and eagle. Renaming mlp to medusa …
7b6a0e6  Moving hidden states shift to spec_decode_worker
9d806b3  formatting
e1e3175  Merge branch 'vllm-project:main' into eagle
8db174f  Adding vocab truncation to EAGLE
b6b0548  Minor changes and fixes. Adding expand model request and hidden state…
f9cbd49  Merge branch 'vllm-project:main' into eagle
89184a1  Merge branch 'main' into eagle
cf8b685  Merge branch 'vllm-project:main' into eagle
3c24f4b  Merge branch 'vllm-project:main' into eagle
a94ea89  Merge branch 'main' into eagle
eaa586c  Removing commented code and a minor comment fix
38e2b5c  formatting
2f17900  Merge branch 'vllm-project:main' into eagle
c5f8d15  adding comments to clarify compatibility of eagle checkpoint in eagle.py
53ab660  Merge branch 'vllm-project:main' into eagle
7f46c68  fixing model_cls resolution in eagle
5e5d214  fixing model_cls resolution in eagle
17c0fc6  Merge branch 'main' into eagle
ad04e7f  adding doctrings to EAGLE and Medusa models
90bee1d  fixing hidden states handling in batch expansion
88c20e6  making HiddenStates a dataclass and renaming last_non_bonus_hidden_st…
3ff257b  Merge branch 'hidden_states_fix' of github.fkinternal.com:abhinav-goy…
1753d9a  reformatting
99484ae  adding acceptance rate test for large output length
2e51385  fixing hidden states manipulation for batch expansion
faa2e28  Merge branch 'hidden_states_fix' of github.fkinternal.com:abhinav-goy…
d8bcff0  print acceptance rate in spec decode tests
1654d4d  Updating HiddenStates to handle prefill step as well
6954ead  changing expected acceptance rate for test
08b3cd5  Merge branch 'vllm-project:main' into hidden_states_fix
5815ccc  Merge branch 'vllm-project:main' into hidden_states_fix
601c816  Merge branch 'hidden_states_fix' into eagle
df87143  Adding explanation for trucated vocab and merging main
f906cef  formatting
2147583  Merge branch 'main' of github.fkinternal.com:abhinav-goyal/vllm into …
3febb95  Fixing compatibility of `worker.multi_step_worker.MultiStepWorker` wi…
90582e2  Merge branch 'main' into eagle
af5552b  Merge branch 'main' into eagle
284468d  adding comment
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens.

With these tests, we can say that, at a minimum, EAGLE does not break the
correctness of the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
PRECISION = "float32"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Verify greedy equality with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify that EAGLE speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
                             batch_size: int, output_len: int):
    """Verify that EAGLE speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


if __name__ == "__main__":
    pytest.main([__file__])
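Note that `run_greedy_equality_correctness_test` comes from the local `conftest.py` and is not part of this diff. A rough sketch of the comparison it is assumed to perform, with hypothetical names for the generated token-id lists:

# Rough sketch only; the real helper lives in the spec_decode e2e conftest and
# also handles prompts, seeds, and metric collection for the vLLM generators.
def greedy_equality_check(baseline_token_ids, spec_token_ids):
    """Assert speculative decoding reproduces the non-speculative greedy output."""
    assert len(baseline_token_ids) == len(spec_token_ids)
    for baseline_ids, spec_ids in zip(baseline_token_ids, spec_token_ids):
        # Rejection sampling preserves the target distribution, so under greedy
        # sampling the token ids must match exactly, position by position.
        assert baseline_ids == spec_ids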
@@ -0,0 +1,86 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.eagle import EAGLEConfig


class EAGLE(nn.Module):
    """EAGLE draft model for speculative decoding.

    Wraps the base architecture named in the EAGLE config: the input token
    embeddings are concatenated with the target model's previous hidden
    states, projected back to hidden_size by `self.fc`, and fed to the
    wrapped model as inputs_embeds.
    """

    def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
        super().__init__()
        self.config = config

        architectures = getattr(self.config.model, "architectures", [])
        model_cls = None
        for arch in architectures:
            model_cls = ModelRegistry.load_model_cls(arch)
            if model_cls is not None:
                break
        if model_cls is None:
            # Fail loudly instead of hitting an unbound model_cls below.
            raise ValueError(
                f"Unable to resolve a model class for {architectures}")

        self.model = model_cls(self.config.model, *args, **kwargs)
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=False)

        self.token_map = None

    @property
    def sampler(self):
        return self.model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:

        tok_embeds = self.model.model.embed_tokens(input_ids)

        # Fuse the token embeddings with the target model's hidden states from
        # the previous step.
        inputs_embeds = self.fc(
            torch.cat([tok_embeds, previous_hidden_states], dim=-1))

        # There is no previous hidden state for the first token, so mask the
        # inputs at position == 0.
        inputs_embeds[positions == 0] = 0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Remap checkpoint names for compatibility: "fc." weights belong to
        # this wrapper; everything else is delegated to the wrapped model,
        # adding the "model." prefix where the checkpoint omits it.
        model_weights = []
        for name, loaded_weight in weights:
            if name.startswith("fc."):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights.append((name, loaded_weight))
            else:
                model_weights.append((f"model.{name}", loaded_weight))

        self.model.load_weights(model_weights)
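For reference, enabling this EAGLE draft model outside the test harness looks roughly like the sketch below. The model names and kwargs mirror the ones exercised by the E2E tests above; the snippet itself is not part of this diff.

from vllm import LLM, SamplingParams

# Same tiny models and settings as the E2E test; adjust for real workloads.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
    num_speculative_tokens=4,
    use_v2_block_manager=True,  # required for spec decode at the time of this PR
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)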