
🧪 add long context test #330


Merged: 1 commit, Jul 23, 2025
tests/e2e/test_spyre_cb.py (57 changes: 56 additions & 1 deletion)
@@ -4,7 +4,8 @@
 """

 import pytest
-from spyre_util import generate_spyre_vllm_output, get_spyre_model_list
+from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts,
+                        get_spyre_model_list)
 from vllm import SamplingParams


@@ -41,3 +42,57 @@ def test_cb_max_tokens(
         max_num_seqs=2,
         use_cb=True,
         monkeypatch=monkeypatch)
+
+
+@pytest.mark.cb
+@pytest.mark.spyre
+@pytest.mark.xfail  # TODO: remove once a spyre-base image supports this
+@pytest.mark.parametrize("model", get_spyre_model_list())
+def test_continuous_batching_with_long_contexts(model, monkeypatch):
+    """Tests that continuous batching generates the same outputs on the spyre
+    cards as it does on cpu, when the max context length is set to 4k.
+    This ensures that the compiler is generating the correct programs for long
+    context cases, but we test here with small prompts for speed.
+
+    Importantly, we're generating the cpu results to compare against using vllm
+    as well, instead of using transformers directly. This ensures that the model
+    code is all the same, and the only difference is the torch compilation
+    backend.
+    """
+    max_model_len = 4096
+    max_num_seqs = 4
+    prompts = get_chicken_soup_prompts(4)
+
+    sampling_params = SamplingParams(max_tokens=20,
+                                     temperature=0,
+                                     ignore_eos=True,
+                                     logprobs=0)
+
+    vllm_cpu_results = generate_spyre_vllm_output(
+        model=model,
+        prompts=prompts,
+        max_model_len=max_model_len,
+        block_size=max_model_len,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        backend="eager",
+        max_num_seqs=max_num_seqs,
+        use_cb=True,
+        monkeypatch=monkeypatch)
+
+    vllm_spyre_results = generate_spyre_vllm_output(
+        model=model,
+        prompts=prompts,
+        max_model_len=max_model_len,
+        block_size=max_model_len,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        backend="sendnn",
+        max_num_seqs=max_num_seqs,
+        use_cb=True,
+        monkeypatch=monkeypatch)
+
+    for i in range(len(vllm_cpu_results)):
+        # As long as no sequences have top candidate tokens with very close
+        # logprobs, the generated text should be identical.
+        assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]
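
A rough sketch of how just this new test could be exercised: the node id below comes from the diff, while pytest.main and the -v flag are generic pytest features rather than anything this PR adds. Because of the xfail marker, a failing run is reported as XFAIL (and a passing run as XPASS) until the marker is removed.

# Sketch only: run the new long-context test by its node id from the repo root.
import pytest

pytest.main([
    "tests/e2e/test_spyre_cb.py::test_continuous_batching_with_long_contexts",
    "-v",  # verbose output; purely optional
])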
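The comparison loop relies on greedy decoding (temperature=0): the CPU and Spyre texts can only diverge if, at some decode step, the two best candidate tokens have nearly tied logprobs. A small, purely hypothetical sketch of that near-tie idea (the helper and data layout below are illustrative and not part of this PR or the repo's utilities):

# Hypothetical check: greedy decoding is unambiguous when, at every step, the
# best token's logprob beats the runner-up by a clear margin.
def top_tokens_well_separated(per_step_logprobs: list[dict[str, float]],
                              min_gap: float = 0.05) -> bool:
    for step in per_step_logprobs:
        ranked = sorted(step.values(), reverse=True)
        if len(ranked) > 1 and ranked[0] - ranked[1] < min_gap:
            return False
    return True


# Illustrative data: logprobs of the two best candidates at two decode steps.
steps = [{"dog": -0.1, "cat": -2.3}, {"ran": -0.50, "walked": -0.52}]
print(top_tokens_well_separated(steps))  # False: the second step is a near-tie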