 import atexit
 import logging
+import socket
 import time
 from typing import Optional
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
@@ -47,10 +49,13 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
 
     Args:
+        base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
+            ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -81,19 +86,42 @@ class VLLMClient:
         >>> client.init_communicator()
         >>> client.update_model_params(model)
         ```
+
+        There are several ways to initialize the client:
+
+        ```python
+        VLLMClient(base_url="http://localhost:8000")
+        VLLMClient(base_url="http://192.168.1.100:8000")
+        VLLMClient(host="localhost", server_port=8000)
+        VLLMClient(host="192.168.1.100", server_port=8000)
+        ```
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0,
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            self.host = socket.gethostbyname(parsed_url.hostname)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
 
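As an illustration (not part of the patch itself), here is a minimal standalone sketch of the `base_url` handling that the `__init__` above relies on, using only `urlparse` and `socket.gethostbyname` and assuming URLs of the shapes shown in the docstring:

```python
# Illustration only: how a base_url breaks down into the pieces the client needs.
import socket
from urllib.parse import urlparse

parsed = urlparse("http://localhost:8000")
print(parsed.scheme)    # "http"
print(parsed.netloc)    # "localhost:8000"
print(parsed.hostname)  # "localhost"

# The hostname is resolved to an IP address for the weight-update group ...
print(socket.gethostbyname(parsed.hostname))  # e.g. "127.0.0.1"

# ... while HTTP endpoints are built from the normalized base URL.
base_url = f"{parsed.scheme or 'http'}://{parsed.netloc}{parsed.path}"
print(f"{base_url}/health/")  # "http://localhost:8000/health/"
```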
@@ -108,7 +136,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
         start_time = time.time()  # Record the start time
 
         while True:
@@ -119,11 +147,13 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
-                        "seconds. Make sure the server is running by running `trl vllm-serve`."
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
+                        "sure the server is running by running `trl vllm-serve`."
                     ) from exc
             else:
                 if response.status_code == 200:
+                    if "X-Forwarded-For" in response.headers:
+                        self.host = response.headers["X-Forwarded-For"]
                     logger.info("Server is up!")
                     return None
 
@@ -170,7 +200,7 @@ def generate(
             `list[list[int]]`:
                 List of lists of token IDs representing the model-generated completions for each prompt.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
         response = self.session.post(
             url,
             json={
@@ -195,7 +225,7 @@ def init_communicator(self):
         Initializes the weight update group in a distributed setup for model synchronization.
         """
         # Get the world size from the server
-        url = f"http://{self.host}:{self.server_port}/get_world_size/"
+        url = f"{self.base_url}/get_world_size/"
         response = requests.get(url)
         if response.status_code == 200:
             vllm_world_size = response.json()["world_size"]
@@ -206,7 +236,7 @@ def init_communicator(self):
         self.rank = vllm_world_size  # the client's rank is the last process
 
         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
         # In the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
         if response.status_code != 200:
@@ -235,7 +265,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        url = f"{self.base_url}/update_named_param/"
         response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -260,7 +290,7 @@ def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
         """
-        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        url = f"{self.base_url}/reset_prefix_cache/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -269,7 +299,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"{self.base_url}/close_communicator/"
 
         try:
             response = self.session.post(url)
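Taken together, a hedged usage sketch of the new constructor path. It assumes a server started with `trl vllm-serve` is reachable at the given URL; the import path and `generate` call follow the class docstring shown above.

```python
from trl.extras.vllm_client import VLLMClient

# Connect via the new base_url parameter; host and server_port would be ignored here.
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120.0)

# Endpoints such as /health/ and /generate/ are now resolved against client.base_url.
completions = client.generate(["Hello, AI!"])  # -> list of token-ID lists
```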