Skip to content

Commit 9d7d963

Browse files
authored
🧪 add long context test (#330)
# Description

This adds a failing test case for handling batch size 4 at 4k context length with continuous batching. This tests the `eager` backend against the `sendnn` backend, which is currently failing on `quay.io/ibm-aiu/spyre-base:2025_07_18-amd64` Signed-off-by: Joe Runde <[email protected]>
1 parent 3a7cc4b commit 9d7d963

File tree

1 file changed

+56
-1
lines changed

1 file changed

+56
-1
lines changed

tests/e2e/test_spyre_cb.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
"""
55

66
import pytest
7-
from spyre_util import generate_spyre_vllm_output, get_spyre_model_list
7+
from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts,
8+
get_spyre_model_list)
89
from vllm import SamplingParams
910

1011

@@ -41,3 +42,57 @@ def test_cb_max_tokens(
4142
max_num_seqs=2,
4243
use_cb=True,
4344
monkeypatch=monkeypatch)
45+
46+
47+
@pytest.mark.cb
@pytest.mark.spyre
@pytest.mark.xfail  # TODO: remove once a spyre-base image supports this
@pytest.mark.parametrize("model", get_spyre_model_list())
def test_continuous_batching_with_long_contexts(model, monkeypatch):
    """Tests that continuous batching generates the same outputs on the spyre
    cards as it does on cpu, when the max context length is set to 4k.
    This ensures that the compiler is generating the correct programs for long
    context cases, but we test here with small prompts for speed.

    Importantly, we're generating the cpu results to compare against using vllm
    as well, instead of using transformers directly. This ensures that the model
    code is all the same, and the only difference is the torch compilation
    backend.
    """
    max_model_len = 4096
    max_num_seqs = 4
    prompts = get_chicken_soup_prompts(4)

    # Greedy decoding (temperature=0) with ignore_eos makes both runs
    # deterministic and the same fixed length, so the texts are comparable.
    sampling_params = SamplingParams(max_tokens=20,
                                     temperature=0,
                                     ignore_eos=True,
                                     logprobs=0)

    def _generate(backend):
        # Everything except the torch compilation backend is identical
        # between the cpu and spyre runs.
        return generate_spyre_vllm_output(model=model,
                                          prompts=prompts,
                                          max_model_len=max_model_len,
                                          block_size=max_model_len,
                                          sampling_params=sampling_params,
                                          tensor_parallel_size=1,
                                          backend=backend,
                                          max_num_seqs=max_num_seqs,
                                          use_cb=True,
                                          monkeypatch=monkeypatch)

    vllm_cpu_results = _generate("eager")
    vllm_spyre_results = _generate("sendnn")

    # Guard against a truncated result set: zip alone would silently skip
    # the missing tail and let the comparison pass vacuously.
    assert len(vllm_cpu_results) == len(vllm_spyre_results)
    for cpu_result, spyre_result in zip(vllm_cpu_results,
                                        vllm_spyre_results):
        # As long as no sequences have top candidate tokens with very close
        # logprobs, the generated text should be identical.
        assert cpu_result["text"] == spyre_result["text"]

0 commit comments

Comments
 (0)