16
16
import os
17
17
import re
18
18
import subprocess
19
+ import tempfile
19
20
20
21
import pytest
21
22
import yaml
@@ -1203,6 +1204,43 @@ def run_disaggregated_benchmark(example_dir,
1203
1204
workers_proc .wait ()
1204
1205
1205
1206
1207
def get_config_for_benchmark(model_root, backend):
    """Build the disaggregated-serving config dict used by the benchmark tests.

    Args:
        model_root: Model path/name placed under the top-level ``model`` key.
        backend: Cache-transceiver backend name (e.g. ``"ucx"`` or ``"nixl"``),
            applied to both the context and generation server groups.

    Returns:
        A plain dict (ready for ``yaml.dump``) describing one context server
        on ``localhost:8001`` and one generation server on ``localhost:8002``,
        both with identical batching/parallelism limits.
    """

    def _server_group(url, **overrides):
        # Fresh dicts on every call so the two server groups never share
        # objects (shared objects would make yaml.dump emit anchors/aliases).
        group = {
            "num_instances": 1,
            "max_batch_size": 2,
            "max_num_tokens": 384,
            "max_seq_len": 320,
            "tensor_parallel_size": 1,
            "pipeline_parallel_size": 1,
            "cache_transceiver_config": {
                "backend": backend,
                "max_tokens_in_buffer": 512,
            },
            "urls": [url],
        }
        group.update(overrides)
        return group

    return {
        "model": model_root,
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        # Only the context side disables the overlap scheduler.
        "context_servers": _server_group("localhost:8001",
                                         disable_overlap_scheduler=True),
        "generation_servers": _server_group("localhost:8002"),
    }
1242
+
1243
+
1206
1244
@pytest .mark .parametrize ("benchmark_model_root" , [
1207
1245
'DeepSeek-V3-Lite-fp8' , 'DeepSeek-V3-Lite-bf16' , 'llama-v3-8b-hf' ,
1208
1246
'llama-3.1-8b-instruct-hf-fp8'
@@ -1211,32 +1249,28 @@ def run_disaggregated_benchmark(example_dir,
1211
1249
def test_disaggregated_benchmark_on_diff_backends (
1212
1250
disaggregated_test_root , disaggregated_example_root , llm_venv ,
1213
1251
benchmark_model_root , benchmark_root , shared_gpt_path ):
1214
- base_config_path = os .path .join (os .path .dirname (__file__ ), "test_configs" ,
1215
- "disagg_config_for_benchmark.yaml" )
1216
- with open (base_config_path , 'r' , encoding = 'utf-8' ) as f :
1217
- config = yaml .load (f , Loader = yaml .SafeLoader )
1218
- config ["model" ] = benchmark_model_root
1219
- with open ("ucx_config.yaml" , 'w' , encoding = 'utf-8' ) as ucx_config :
1220
- yaml .dump (config , ucx_config )
1221
- config ["context_servers" ]["cache_transceiver_config" ][
1222
- "backend" ] = "nixl"
1223
- config ["generation_servers" ]["cache_transceiver_config" ][
1224
- "backend" ] = "nixl"
1225
- with open ("nixl_config.yaml" , 'w' , encoding = 'utf-8' ) as nixl_config :
1226
- yaml .dump (config , nixl_config )
1252
+ nixl_config = get_config_for_benchmark (benchmark_model_root , "nixl" )
1253
+ ucx_config = get_config_for_benchmark (benchmark_model_root , "ucx" )
1254
+ temp_dir = tempfile .TemporaryDirectory ()
1255
+ nixl_config_path = os .path .join (temp_dir .name , "nixl_config.yaml" )
1256
+ ucx_config_path = os .path .join (temp_dir .name , "ucx_config.yaml" )
1257
+ with open (nixl_config_path , 'w' , encoding = 'utf-8' ) as f :
1258
+ yaml .dump (nixl_config , f )
1259
+ with open (ucx_config_path , 'w' , encoding = 'utf-8' ) as f :
1260
+ yaml .dump (ucx_config , f )
1227
1261
1228
1262
env = llm_venv ._new_env .copy ()
1229
1263
nixl_e2el , nixl_ttft = run_disaggregated_benchmark (
1230
1264
disaggregated_example_root ,
1231
- f" { os . path . dirname ( __file__ ) } /nixl_config.yaml" ,
1265
+ nixl_config_path ,
1232
1266
benchmark_root ,
1233
1267
benchmark_model_root ,
1234
1268
shared_gpt_path ,
1235
1269
env = env ,
1236
1270
cwd = llm_venv .get_working_directory ())
1237
1271
ucx_e2el , ucx_ttft = run_disaggregated_benchmark (
1238
1272
disaggregated_example_root ,
1239
- f" { os . path . dirname ( __file__ ) } /ucx_config.yaml" ,
1273
+ ucx_config_path ,
1240
1274
benchmark_root ,
1241
1275
benchmark_model_root ,
1242
1276
shared_gpt_path ,
0 commit comments