|
4 | 4 | """
|
5 | 5 |
|
6 | 6 | import pytest
|
7 |
| -from spyre_util import generate_spyre_vllm_output, get_spyre_model_list |
| 7 | +from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts, |
| 8 | + get_spyre_model_list) |
8 | 9 | from vllm import SamplingParams
|
9 | 10 |
|
10 | 11 |
|
@@ -41,3 +42,57 @@ def test_cb_max_tokens(
|
41 | 42 | max_num_seqs=2,
|
42 | 43 | use_cb=True,
|
43 | 44 | monkeypatch=monkeypatch)
|
| 45 | + |
| 46 | + |
@pytest.mark.cb
@pytest.mark.spyre
@pytest.mark.xfail  # TODO: remove once a spyre-base image supports this
@pytest.mark.parametrize("model", get_spyre_model_list())
def test_continuous_batching_with_long_contexts(model, monkeypatch):
    """Tests that continuous batching generates the same outputs on the spyre
    cards as it does on cpu, when the max context length is set to 4k.
    This ensures that the compiler is generating the correct programs for long
    context cases, but we test here with small prompts for speed.

    Importantly, we're generating the cpu results to compare against using vllm
    as well, instead of using transformers directly. This ensures that the model
    code is all the same, and the only difference is the torch compilation
    backend.
    """
    max_model_len = 4096
    max_num_seqs = 4
    prompts = get_chicken_soup_prompts(4)

    # temperature=0 makes generation greedy/deterministic so the two backends
    # can be compared by exact text equality; ignore_eos keeps output lengths
    # fixed at max_tokens.
    sampling_params = SamplingParams(max_tokens=20,
                                     temperature=0,
                                     ignore_eos=True,
                                     logprobs=0)

    # Both runs share every setting except the compilation backend, so factor
    # the common arguments out to guarantee the comparison is apples-to-apples.
    common_kwargs = dict(
        model=model,
        prompts=prompts,
        max_model_len=max_model_len,
        block_size=max_model_len,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        max_num_seqs=max_num_seqs,
        use_cb=True,
        monkeypatch=monkeypatch,
    )

    vllm_cpu_results = generate_spyre_vllm_output(backend="eager",
                                                  **common_kwargs)

    vllm_spyre_results = generate_spyre_vllm_output(backend="sendnn",
                                                    **common_kwargs)

    # Fail loudly if one backend produced fewer results; a bare zip would
    # silently truncate the comparison.
    assert len(vllm_cpu_results) == len(vllm_spyre_results)
    for cpu_result, spyre_result in zip(vllm_cpu_results, vllm_spyre_results):
        # As long as no sequences have top candidate tokens with very close
        # logprobs, the generated text should be identical.
        assert cpu_result["text"] == spyre_result["text"]
0 commit comments