
🏰 [vllm] Support base_url parameter for vLLM client initialization #3324


Merged
23 commits merged on May 27, 2025

Commits (23)
0bb3624
refactor
re-imagined Apr 27, 2025
4d6c670
refactor
re-imagined Apr 28, 2025
a82665e
refactor
re-imagined Apr 28, 2025
5df0d54
refactor
re-imagined Apr 28, 2025
4795c7d
Merge branch 'main' into vllm_client_custom_url
re-imagined Apr 28, 2025
e18b6be
fix
re-imagined Apr 30, 2025
7959ed5
fix host
re-imagined May 4, 2025
36a7e91
Merge branch 'main' into vllm_client_custom_url
re-imagined May 6, 2025
9dddd65
fix(vllm_client): update host value using X-Forwarded-For header or s…
re-imagined May 6, 2025
168ce1c
Merge branch 'main' into vllm_client_custom_url
re-imagined May 7, 2025
cc3d093
fix(vllm_client): enhance server connectivity and IP retrieval
re-imagined May 8, 2025
0841060
Merge branch 'main' into vllm_client_custom_url
re-imagined May 8, 2025
6e71207
Merge branch 'main' into vllm_client_custom_url
re-imagined May 8, 2025
879aff3
Merge branch 'main' into vllm_client_custom_url
re-imagined May 8, 2025
7dd62b8
Merge branch 'vllm_client_custom_url' of github.com:re-imagined/trl i…
re-imagined May 8, 2025
e578776
Merge branch 'main' into vllm_client_custom_url
re-imagined May 9, 2025
930c7a6
Merge branch 'main' into vllm_client_custom_url
re-imagined May 16, 2025
4a96303
Merge branch 'main' into vllm_client_custom_url
qgallouedec May 20, 2025
3cf1ee0
update test
qgallouedec May 20, 2025
ca72e63
get host when url is parsed
qgallouedec May 20, 2025
fe3c9eb
format
qgallouedec May 20, 2025
c3bffcd
Merge branch 'main' into vllm_client_custom_url
re-imagined May 24, 2025
f4ff4c6
Merge branch 'main' into vllm_client_custom_url
qgallouedec May 27, 2025
79 changes: 79 additions & 0 deletions tests/test_vllm_client_server.py
@@ -132,6 +132,85 @@ def tearDownClass(cls):
cls.server_process.wait()


# Same as above but using base_url to instantiate the client.
@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServerBaseURL(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1

# Start the server process
cls.server_process = subprocess.Popen(
["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)

# Initialize the client
cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
cls.client.init_communicator()

def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is equal to the number of prompts
self.assertEqual(len(outputs), len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is 2 times the number of prompts
self.assertEqual(len(outputs), 2 * len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
self.assertLessEqual(len(seq), 32)

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
# Test resetting the prefix cache
self.client.reset_prefix_cache()

@classmethod
def tearDownClass(cls):
super().tearDownClass()

# Close the client
cls.client.close_communicator()

# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
parent = psutil.Process(cls.server_process.pid)
children = parent.children(recursive=True)
for child in children:
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()


@pytest.mark.slow
@require_3_accelerators
class TestVLLMClientServerTP(unittest.TestCase):
58 changes: 44 additions & 14 deletions trl/extras/vllm_client.py
@@ -14,8 +14,10 @@

import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse

import torch
from torch import nn
@@ -47,10 +49,13 @@ class VLLMClient:
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
Args:
base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
ignored.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server.
IP address of the vLLM server. Ignored if `base_url` is provided.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server.
Port number of the vLLM server. Ignored if `base_url` is provided.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -81,19 +86,42 @@ class VLLMClient:
>>> client.init_communicator()
>>> client.update_model_params(model)
```
There are several ways to initialize the client:
```python
VLLMClient(base_url="http://localhost:8000")
VLLMClient(base_url="http://192.168.1.100:8000")
VLLMClient(host="localhost", server_port=8000)
VLLMClient(host="192.168.1.100", server_port=8000)
```
"""

def __init__(
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
self,
base_url: Optional[str] = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0,
):
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
if not is_vllm_available():
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

self.session = requests.Session()
self.host = host
self.server_port = server_port

if base_url is not None:
# Parse the base_url to extract host and port
parsed_url = urlparse(base_url)
self.host = socket.gethostbyname(parsed_url.hostname)
scheme = parsed_url.scheme or "http"
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
else:
self.host = host
self.server_port = server_port
self.base_url = f"http://{self.host}:{self.server_port}"
self.group_port = group_port
self.check_server(connection_timeout) # check server and fail after timeout
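A quick illustration of how a provided base URL is decomposed by the parsing above (hypothetical URL; note that any path component, e.g. one added by a reverse proxy, is preserved in `base_url`):

```python
import socket
from urllib.parse import urlparse

parsed = urlparse("http://localhost:8000/vllm")        # hypothetical URL
host = socket.gethostbyname(parsed.hostname)           # e.g. "127.0.0.1"
scheme = parsed.scheme or "http"
base_url = f"{scheme}://{parsed.netloc}{parsed.path}"  # "http://localhost:8000/vllm"
```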

@@ -108,7 +136,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
total_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds.
"""
url = f"http://{self.host}:{self.server_port}/health/"
url = f"{self.base_url}/health/"
start_time = time.time() # Record the start time

while True:
@@ -119,11 +147,13 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
"seconds. Make sure the server is running by running `trl vllm-serve`."
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
"sure the server is running by running `trl vllm-serve`."
) from exc
else:
if response.status_code == 200:
if "X-Forwarded-For" in response.headers:
Contributor Author:

We still need the IP address when using a base URL:

  1. An HTTP proxy can add a header containing the backend IP so that the client can read it from the response; the most common header for passing client IPs is X-Forwarded-For.
  2. Get the IP via a request.

Member:

I've updated the host to be `self.host = socket.gethostbyname(parsed_url.hostname)` during the init. It's simpler. It works locally, but I'm curious to know if it works for you.
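A minimal sketch of the header-based fallback described above (assumes a reverse proxy that sets X-Forwarded-For; the URL is illustrative):

```python
import requests

base_url = "http://localhost:8000"  # illustrative
response = requests.get(f"{base_url}/health/")

# If a proxy exposes the backend IP, prefer it over the DNS-resolved host.
if "X-Forwarded-For" in response.headers:
    host = response.headers["X-Forwarded-For"]
```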

self.host = response.headers["X-Forwarded-For"]
logger.info("Server is up!")
return None

Expand Down Expand Up @@ -170,7 +200,7 @@ def generate(
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
url = f"http://{self.host}:{self.server_port}/generate/"
url = f"{self.base_url}/generate/"
response = self.session.post(
url,
json={
@@ -195,7 +225,7 @@ def init_communicator(self):
Initializes the weight update group in a distributed setup for model synchronization.
"""
# Get the world size from the server
url = f"http://{self.host}:{self.server_port}/get_world_size/"
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()["world_size"]
@@ -206,7 +236,7 @@
self.rank = vllm_world_size # the client's rank is the last process

# Initialize weight update group
url = f"http://{self.host}:{self.server_port}/init_communicator/"
url = f"{self.base_url}/init_communicator/"
# On the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
@@ -235,7 +265,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
Tensor containing the updated weights.
"""
dtype, shape = str(weights.dtype), tuple(weights.shape)
url = f"http://{self.host}:{self.server_port}/update_named_param/"
url = f"{self.base_url}/update_named_param/"
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -260,7 +290,7 @@ def reset_prefix_cache(self):
"""
Resets the prefix cache for the model.
"""
url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
url = f"{self.base_url}/reset_prefix_cache/"
response = self.session.post(url)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -269,7 +299,7 @@ def close_communicator(self):
"""
Closes the weight update group and cleans up the communication group.
"""
url = f"http://{self.host}:{self.server_port}/close_communicator/"
url = f"{self.base_url}/close_communicator/"

try:
response = self.session.post(url)
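Putting the client changes together, a minimal usage sketch (mirroring the new test above; the model name, URL, and import path are examples, and a server must already be running via `trl vllm-serve`):

```python
from trl.extras.vllm_client import VLLMClient

# Assumes a server was started separately, e.g.:
#   trl vllm-serve --model Qwen/Qwen2.5-1.5B
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
client.init_communicator()

# Returns a list of token-ID lists, one per prompt.
outputs = client.generate(["Hello, AI!", "Tell me a joke"], max_tokens=32)
print(len(outputs))  # 2

client.close_communicator()
```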
19 changes: 14 additions & 5 deletions trl/trainer/grpo_config.py
@@ -104,11 +104,13 @@ class GRPOConfig(TrainingArguments):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)

vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
`vllm_server_port` are ignored.
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host of the vLLM server to connect to.
Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_port (`int`, *optional*, defaults to `8000`):
Port of the vLLM server to connect to.
Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
@@ -318,6 +320,13 @@ class GRPOConfig(TrainingArguments):
"generation instead of the default model.generate(). Requires `vllm` to be installed."
},
)
vllm_server_base_url: Optional[str] = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
"and `vllm_server_port` are ignored."
},
)
vllm_mode: str = field(
default="server",
metadata={
@@ -336,11 +345,11 @@
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_host: str = field(
default="0.0.0.0",
metadata={"help": "Host of the vLLM server to connect to."},
metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_port: int = field(
default=8000,
metadata={"help": "Port of the vLLM server to connect to."},
metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_timeout: float = field(
default=240.0,
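For illustration, a sketch of pointing GRPO at a vLLM server through the new field (the `output_dir` and `use_vllm` arguments are assumptions outside this diff; the URL is an example):

```python
from trl import GRPOConfig

# Sketch only: when vllm_server_base_url is set, vllm_server_host and
# vllm_server_port are ignored, per the docstring above.
args = GRPOConfig(
    output_dir="grpo-output",
    use_vllm=True,  # assumed flag enabling server-based vLLM generation
    vllm_server_base_url="http://192.168.1.100:8000",  # example URL
    vllm_server_timeout=240.0,
)
```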
8 changes: 5 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -617,9 +617,11 @@ def data_collator(features): # No data collation is needed in GRPO
)

if self.vllm_mode == "server" and self.accelerator.is_main_process:
self.vllm_client = VLLMClient(
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
)
if args.vllm_server_base_url is not None:
base_url = args.vllm_server_base_url
else:
base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
self.vllm_client.init_communicator()

elif self.vllm_mode == "colocate":