
Commit 824100c

🏰 [vllm] Support base_url parameter for vLLM client initialization (#3324)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 4e7f0a5 commit 824100c

File tree: 4 files changed (+142, -22 lines)
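
As a quick illustration of the change, the client can now be pointed at a vLLM server through a single URL rather than a separate host/port pair. A minimal usage sketch (the server address below is a placeholder, not a value from this commit):

```python
from trl.extras.vllm_client import VLLMClient

# Placeholder address: any vLLM server started with `trl vllm-serve` works the same way.
client = VLLMClient(base_url="http://192.168.1.100:8000", connection_timeout=240)
client.init_communicator()

outputs = client.generate(["Hello, AI!"], max_tokens=32)  # list of token-ID lists
client.close_communicator()
```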

‎tests/test_vllm_client_server.py

Lines changed: 79 additions & 0 deletions

@@ -132,6 +132,85 @@ def tearDownClass(cls):
         cls.server_process.wait()
 
 
+# Same as above but using base_url to instantiate the client.
+@pytest.mark.slow
+@require_torch_multi_accelerator
+class TestVLLMClientServerBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
+        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+        # Initialize the client
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
+        cls.client.init_communicator()
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        self.assertEqual(len(outputs), 2 * len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in outputs:
+            self.assertLessEqual(len(seq), 32)
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
+
+
 @pytest.mark.slow
 @require_3_accelerators
 class TestVLLMClientServerTP(unittest.TestCase):

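The new test class mirrors `TestVLLMClientServer` and differs only in how the client is constructed. A sketch of selecting just this class with pytest's programmatic API; the `-m slow` filter is an assumption about how the slow marker is collected in a given setup, and the tests skip themselves without two accelerators:

```python
import pytest

# Run only the base_url variant of the client/server tests.
pytest.main([
    "tests/test_vllm_client_server.py",
    "-k", "TestVLLMClientServerBaseURL",
    "-m", "slow",  # assumption: slow-marked tests are selected by marker in this setup
])
```
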
‎trl/extras/vllm_client.py

Lines changed: 44 additions & 14 deletions

@@ -14,8 +14,10 @@
 
 import atexit
 import logging
+import socket
 import time
 from typing import Optional
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
@@ -47,10 +49,13 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
 
     Args:
+        base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
+            ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -81,19 +86,42 @@ class VLLMClient:
     >>> client.init_communicator()
     >>> client.update_model_params(model)
     ```
+
+    There are several ways to initialize the client:
+
+    ```python
+    VLLMClient(base_url="http://localhost:8000")
+    VLLMClient(base_url="http://192.168.1.100:8000")
+    VLLMClient(host="localhost", server_port=8000)
+    VLLMClient(host="192.168.1.100", server_port=8000)
+    ```
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0,
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            self.host = socket.gethostbyname(parsed_url.hostname)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
 
@@ -108,7 +136,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
         start_time = time.time()  # Record the start time
 
         while True:
@@ -119,11 +147,13 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
-                        "seconds. Make sure the server is running by running `trl vllm-serve`."
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
+                        "sure the server is running by running `trl vllm-serve`."
                     ) from exc
             else:
                 if response.status_code == 200:
+                    if "X-Forwarded-For" in response.headers:
+                        self.host = response.headers["X-Forwarded-For"]
                     logger.info("Server is up!")
                     return None
 
@@ -170,7 +200,7 @@ def generate(
             `list[list[int]]`:
                 List of lists of token IDs representing the model-generated completions for each prompt.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
         response = self.session.post(
             url,
             json={
@@ -195,7 +225,7 @@ def init_communicator(self):
         Initializes the weight update group in a distributed setup for model synchronization.
         """
         # Get the world size from the server
-        url = f"http://{self.host}:{self.server_port}/get_world_size/"
+        url = f"{self.base_url}/get_world_size/"
         response = requests.get(url)
         if response.status_code == 200:
             vllm_world_size = response.json()["world_size"]
@@ -206,7 +236,7 @@ def init_communicator(self):
         self.rank = vllm_world_size  # the client's rank is the last process
 
         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
         # In the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
         if response.status_code != 200:
@@ -235,7 +265,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        url = f"{self.base_url}/update_named_param/"
         response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -260,7 +290,7 @@ def reset_prefix_cache(self):
         """
        Resets the prefix cache for the model.
         """
-        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        url = f"{self.base_url}/reset_prefix_cache/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -269,7 +299,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"{self.base_url}/close_communicator/"
 
         try:
             response = self.session.post(url)

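To make the constructor's precedence explicit, here is a standalone sketch that mirrors the parsing logic added above. `resolve_base_url` is a hypothetical helper written for illustration, not part of the library:

```python
import socket
from typing import Optional, Tuple
from urllib.parse import urlparse


def resolve_base_url(base_url: Optional[str], host: str = "0.0.0.0", server_port: int = 8000) -> Tuple[str, str]:
    """Mirror VLLMClient's precedence: an explicit base_url wins over host/server_port."""
    if base_url is not None:
        parsed = urlparse(base_url)
        resolved_host = socket.gethostbyname(parsed.hostname)  # resolve a DNS name to an IP
        scheme = parsed.scheme or "http"                       # default to http if no scheme is given
        return resolved_host, f"{scheme}://{parsed.netloc}{parsed.path}"
    return host, f"http://{host}:{server_port}"


print(resolve_base_url("http://localhost:8000"))  # e.g. ('127.0.0.1', 'http://localhost:8000')
print(resolve_base_url(None, host="0.0.0.0"))     # ('0.0.0.0', 'http://0.0.0.0:8000')
```

Note that resolving the hostname up front stores a name like `localhost` as an IP in `self.host`, while the original netloc is kept in `self.base_url` for request routing.
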
‎trl/trainer/grpo_config.py

Lines changed: 14 additions & 5 deletions

@@ -104,11 +104,13 @@ class GRPOConfig(TrainingArguments):
             Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
 
         > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
-
+        vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
+            `vllm_server_port` are ignored.
         vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            Host of the vLLM server to connect to.
+            Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_port (`int`, *optional*, defaults to `8000`):
-            Port of the vLLM server to connect to.
+            Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
@@ -320,6 +322,13 @@ class GRPOConfig(TrainingArguments):
             "generation instead of the default model.generate(). Requires `vllm` to be installed."
         },
     )
+    vllm_server_base_url: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
+            "and `vllm_server_port` are ignored."
+        },
+    )
     vllm_mode: str = field(
         default="server",
         metadata={
@@ -338,11 +347,11 @@ class GRPOConfig(TrainingArguments):
     # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
     vllm_server_host: str = field(
         default="0.0.0.0",
-        metadata={"help": "Host of the vLLM server to connect to."},
+        metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_port: int = field(
         default=8000,
-        metadata={"help": "Port of the vLLM server to connect to."},
+        metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_timeout: float = field(
         default=240.0,

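For reference, a hedged sketch of setting the new field in a training configuration; the output directory and URL are placeholders, and when `vllm_server_base_url` is set the host/port fields are ignored, as the docstring above states:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",                         # placeholder
    use_vllm=True,
    vllm_mode="server",
    vllm_server_base_url="http://my-vllm-host:8000",  # placeholder URL; overrides host/port
    vllm_server_timeout=240.0,
)
```
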
‎trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 3 deletions

@@ -618,9 +618,11 @@ def data_collator(features):  # No data collation is needed in GRPO
                 )
 
             if self.vllm_mode == "server" and self.accelerator.is_main_process:
-                self.vllm_client = VLLMClient(
-                    args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
-                )
+                if args.vllm_server_base_url is not None:
+                    base_url = args.vllm_server_base_url
+                else:
+                    base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
+                self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
                 self.vllm_client.init_communicator()
 
             elif self.vllm_mode == "colocate":

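The branch above boils down to a simple precedence rule. A self-contained sketch, with `compute_base_url` and the `SimpleNamespace` stand-in for a `GRPOConfig` both introduced here for illustration only:

```python
from types import SimpleNamespace


def compute_base_url(args) -> str:
    # Same precedence as the trainer code above: an explicit base URL wins,
    # otherwise fall back to the host/port pair from the config.
    if args.vllm_server_base_url is not None:
        return args.vllm_server_base_url
    return f"http://{args.vllm_server_host}:{args.vllm_server_port}"


cfg = SimpleNamespace(vllm_server_base_url=None, vllm_server_host="0.0.0.0", vllm_server_port=8000)
print(compute_base_url(cfg))  # http://0.0.0.0:8000
```
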