-
Notifications
You must be signed in to change notification settings - Fork 2.1k
🏰 [vllm] Support `base_url` parameter for vLLM client initialization
#3324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
0bb3624
4d6c670
a82665e
5df0d54
4795c7d
e18b6be
7959ed5
36a7e91
9dddd65
168ce1c
cc3d093
0841060
6e71207
879aff3
7dd62b8
e578776
930c7a6
4a96303
3cf1ee0
ca72e63
fe3c9eb
c3bffcd
f4ff4c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,10 @@ | |
|
||
import atexit | ||
import logging | ||
import socket | ||
import time | ||
from typing import Optional | ||
from urllib.parse import urlparse | ||
|
||
import torch | ||
from torch import nn | ||
|
@@ -47,10 +49,13 @@ class VLLMClient: | |
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`. | ||
Args: | ||
base_url (`str` or `None`, *optional*, defaults to `None`): | ||
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are | ||
ignored. | ||
host (`str`, *optional*, defaults to `"0.0.0.0"`): | ||
IP address of the vLLM server. | ||
IP address of the vLLM server. Ignored if `base_url` is provided. | ||
server_port (`int`, *optional*, defaults to `8000`): | ||
Port number of the vLLM server. | ||
Port number of the vLLM server. Ignored if `base_url` is provided. | ||
group_port (`int`, *optional*, defaults to `51216`): | ||
Port number for the weight update group. | ||
connection_timeout (`float`, *optional*, defaults to `0.0`): | ||
|
@@ -81,19 +86,42 @@ class VLLMClient: | |
>>> client.init_communicator() | ||
>>> client.update_model_params(model) | ||
``` | ||
There are several ways to initialize the client: | ||
```python | ||
VLLMClient(base_url="http://localhost:8000") | ||
VLLMClient(base_url="http://192.168.1.100:8000") | ||
VLLMClient(host="localhost", server_port=8000) | ||
VLLMClient(host="192.168.1.100", server_port=8000) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0 | ||
self, | ||
base_url: Optional[str] = None, | ||
host: str = "0.0.0.0", | ||
server_port: int = 8000, | ||
group_port: int = 51216, | ||
connection_timeout: float = 0.0, | ||
): | ||
if not is_requests_available(): | ||
raise ImportError("requests is not installed. Please install it with `pip install requests`.") | ||
if not is_vllm_available(): | ||
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.") | ||
|
||
self.session = requests.Session() | ||
self.host = host | ||
self.server_port = server_port | ||
|
||
if base_url is not None: | ||
# Parse the base_url to extract host and port | ||
parsed_url = urlparse(base_url) | ||
self.host = socket.gethostbyname(parsed_url.hostname) | ||
scheme = parsed_url.scheme or "http" | ||
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}" | ||
else: | ||
self.host = host | ||
self.server_port = server_port | ||
self.base_url = f"http://{self.host}:{self.server_port}" | ||
self.group_port = group_port | ||
self.check_server(connection_timeout) # check server and fail after timeout | ||
|
||
|
@@ -108,7 +136,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): | |
total_timeout (`float`, *optional*, defaults to `0.0`): | ||
Total timeout duration in seconds. | ||
""" | ||
url = f"http://{self.host}:{self.server_port}/health/" | ||
url = f"{self.base_url}/health/" | ||
start_time = time.time() # Record the start time | ||
|
||
while True: | ||
|
@@ -119,11 +147,13 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): | |
elapsed_time = time.time() - start_time | ||
if elapsed_time >= total_timeout: | ||
raise ConnectionError( | ||
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} " | ||
"seconds. Make sure the server is running by running `trl vllm-serve`." | ||
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make " | ||
"sure the server is running by running `trl vllm-serve`." | ||
) from exc | ||
else: | ||
if response.status_code == 200: | ||
if "X-Forwarded-For" in response.headers: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We still need the IP address when using `base_url`.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've updated the host to be
||
self.host = response.headers["X-Forwarded-For"] | ||
logger.info("Server is up!") | ||
return None | ||
|
||
|
@@ -170,7 +200,7 @@ def generate( | |
`list[list[int]]`: | ||
List of lists of token IDs representing the model-generated completions for each prompt. | ||
""" | ||
url = f"http://{self.host}:{self.server_port}/generate/" | ||
url = f"{self.base_url}/generate/" | ||
response = self.session.post( | ||
url, | ||
json={ | ||
|
@@ -195,7 +225,7 @@ def init_communicator(self): | |
Initializes the weight update group in a distributed setup for model synchronization. | ||
""" | ||
# Get the world size from the server | ||
url = f"http://{self.host}:{self.server_port}/get_world_size/" | ||
url = f"{self.base_url}/get_world_size/" | ||
response = requests.get(url) | ||
if response.status_code == 200: | ||
vllm_world_size = response.json()["world_size"] | ||
|
@@ -206,7 +236,7 @@ def init_communicator(self): | |
self.rank = vllm_world_size # the client's rank is the last process | ||
|
||
# Initialize weight update group | ||
url = f"http://{self.host}:{self.server_port}/init_communicator/" | ||
url = f"{self.base_url}/init_communicator/" | ||
# In the server side, the host is set to 0.0.0.0 | ||
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size}) | ||
if response.status_code != 200: | ||
|
@@ -235,7 +265,7 @@ def update_named_param(self, name: str, weights: torch.Tensor): | |
Tensor containing the updated weights. | ||
""" | ||
dtype, shape = str(weights.dtype), tuple(weights.shape) | ||
url = f"http://{self.host}:{self.server_port}/update_named_param/" | ||
url = f"{self.base_url}/update_named_param/" | ||
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) | ||
if response.status_code != 200: | ||
raise Exception(f"Request failed: {response.status_code}, {response.text}") | ||
|
@@ -260,7 +290,7 @@ def reset_prefix_cache(self): | |
""" | ||
Resets the prefix cache for the model. | ||
""" | ||
url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/" | ||
url = f"{self.base_url}/reset_prefix_cache/" | ||
response = self.session.post(url) | ||
if response.status_code != 200: | ||
raise Exception(f"Request failed: {response.status_code}, {response.text}") | ||
|
@@ -269,7 +299,7 @@ def close_communicator(self): | |
""" | ||
Closes the weight update group and cleans up the communication group. | ||
""" | ||
url = f"http://{self.host}:{self.server_port}/close_communicator/" | ||
url = f"{self.base_url}/close_communicator/" | ||
|
||
try: | ||
response = self.session.post(url) | ||
|
Uh oh!
There was an error while loading. Please reload this page.