Skip to content

Commit 18c2bba

Browse files
committed
fix some tests
Signed-off-by: Bo Deng <[email protected]>
1 parent f7345bf commit 18c2bba

File tree

2 files changed

+49
-44
lines changed

2 files changed

+49
-44
lines changed

tests/integration/defs/disaggregated/test_configs/disagg_config_for_benchmark.yaml

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
import subprocess
19+
import tempfile
1920

2021
import pytest
2122
import yaml
@@ -1203,6 +1204,43 @@ def run_disaggregated_benchmark(example_dir,
12031204
workers_proc.wait()
12041205

12051206

1207+
def get_config_for_benchmark(model_root, backend):
    """Build the disaggregated-serving config used by the benchmark test.

    Args:
        model_root: Value stored under the top-level ``"model"`` key.
        backend: Cache-transceiver backend name; written to the
            ``cache_transceiver_config`` of both the context and the
            generation servers. The top-level ``"backend"`` key is always
            ``"pytorch"`` and is not affected by this argument.

    Returns:
        A new ``dict`` with one context server (``localhost:8001``) and one
        generation server (``localhost:8002``) behind ``localhost:8000``.
    """

    def _cache_transceiver_config():
        # Built per call so the context and generation sections never share
        # one dict object (PyYAML serializes shared objects as
        # anchors/aliases when the config is dumped).
        return {
            "backend": backend,
            "max_tokens_in_buffer": 512,
        }

    context_servers = {
        "num_instances": 1,
        "max_batch_size": 2,
        "max_num_tokens": 384,
        "max_seq_len": 320,
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "disable_overlap_scheduler": True,
        "cache_transceiver_config": _cache_transceiver_config(),
        "urls": ["localhost:8001"],
    }
    generation_servers = {
        "num_instances": 1,
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "max_batch_size": 2,
        "max_num_tokens": 384,
        "max_seq_len": 320,
        "cache_transceiver_config": _cache_transceiver_config(),
        "urls": ["localhost:8002"],
    }
    return {
        "model": model_root,
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": context_servers,
        "generation_servers": generation_servers,
    }
1242+
1243+
12061244
@pytest.mark.parametrize("benchmark_model_root", [
12071245
'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
12081246
'llama-3.1-8b-instruct-hf-fp8'
@@ -1211,32 +1249,28 @@ def run_disaggregated_benchmark(example_dir,
12111249
def test_disaggregated_benchmark_on_diff_backends(
12121250
disaggregated_test_root, disaggregated_example_root, llm_venv,
12131251
benchmark_model_root, benchmark_root, shared_gpt_path):
1214-
base_config_path = os.path.join(os.path.dirname(__file__), "test_configs",
1215-
"disagg_config_for_benchmark.yaml")
1216-
with open(base_config_path, 'r', encoding='utf-8') as f:
1217-
config = yaml.load(f, Loader=yaml.SafeLoader)
1218-
config["model"] = benchmark_model_root
1219-
with open("ucx_config.yaml", 'w', encoding='utf-8') as ucx_config:
1220-
yaml.dump(config, ucx_config)
1221-
config["context_servers"]["cache_transceiver_config"][
1222-
"backend"] = "nixl"
1223-
config["generation_servers"]["cache_transceiver_config"][
1224-
"backend"] = "nixl"
1225-
with open("nixl_config.yaml", 'w', encoding='utf-8') as nixl_config:
1226-
yaml.dump(config, nixl_config)
1252+
nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
1253+
ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
1254+
temp_dir = tempfile.TemporaryDirectory()
1255+
nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
1256+
ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
1257+
with open(nixl_config_path, 'w', encoding='utf-8') as f:
1258+
yaml.dump(nixl_config, f)
1259+
with open(ucx_config_path, 'w', encoding='utf-8') as f:
1260+
yaml.dump(ucx_config, f)
12271261

12281262
env = llm_venv._new_env.copy()
12291263
nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
12301264
disaggregated_example_root,
1231-
f"{os.path.dirname(__file__)}/nixl_config.yaml",
1265+
nixl_config_path,
12321266
benchmark_root,
12331267
benchmark_model_root,
12341268
shared_gpt_path,
12351269
env=env,
12361270
cwd=llm_venv.get_working_directory())
12371271
ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
12381272
disaggregated_example_root,
1239-
f"{os.path.dirname(__file__)}/ucx_config.yaml",
1273+
ucx_config_path,
12401274
benchmark_root,
12411275
benchmark_model_root,
12421276
shared_gpt_path,

0 commit comments

Comments
 (0)