Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 3b05677

Browse files
Add tgis tools
Signed-off-by: Rafael Vasquez <[email protected]> Co-authored-by: Prashant Gupta <[email protected]>
1 parent e0b5a8d commit 3b05677

File tree

5 files changed

+539
-0
lines changed

5 files changed

+539
-0
lines changed

tests/tgis/__init__.py

Whitespace-only changes.

tests/tgis/test_hub.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
from huggingface_hub.utils import LocalEntryNotFoundError
5+
6+
from vllm.tgis_utils.hub import (convert_files, download_weights, weight_files,
7+
weight_hub_files)
8+
9+
10+
def test_convert_files():
11+
model_id = "bigscience/bloom-560m"
12+
local_pt_files = download_weights(model_id, extension=".bin")
13+
local_pt_files = [Path(p) for p in local_pt_files]
14+
local_st_files = [
15+
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
16+
for p in local_pt_files
17+
]
18+
convert_files(local_pt_files, local_st_files, discard_names=[])
19+
20+
found_st_files = weight_files(model_id)
21+
22+
assert all([str(p) in found_st_files for p in local_st_files])
23+
24+
25+
def test_weight_hub_files():
26+
filenames = weight_hub_files("bigscience/bloom-560m")
27+
assert filenames == ["model.safetensors"]
28+
29+
30+
def test_weight_hub_files_llm():
31+
filenames = weight_hub_files("bigscience/bloom")
32+
assert filenames == [
33+
f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)
34+
]
35+
36+
37+
def test_weight_hub_files_empty():
38+
filenames = weight_hub_files("bigscience/bloom", ".errors")
39+
assert filenames == []
40+
41+
42+
def test_download_weights():
43+
files = download_weights("bigscience/bloom-560m")
44+
local_files = weight_files("bigscience/bloom-560m")
45+
assert files == local_files
46+
47+
48+
def test_weight_files_error():
49+
with pytest.raises(LocalEntryNotFoundError):
50+
weight_files("bert-base-uncased")

vllm/entrypoints/openai/api_server.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import vllm.envs as envs
1919
from vllm.engine.arg_utils import AsyncEngineArgs
2020
from vllm.engine.async_llm_engine import AsyncLLMEngine
21+
from vllm.entrypoints.grpc.grpc_server import start_grpc_server
2122
from vllm.entrypoints.openai.cli_args import make_arg_parser
2223
# yapf conflicts with isort for this block
2324
# yapf: disable
@@ -34,6 +35,7 @@
3435
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
3536
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
3637
from vllm.logger import init_logger
38+
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
3739
from vllm.usage.usage_lib import UsageContext
3840
from vllm.utils import FlexibleArgumentParser
3941
from vllm.version import __version__ as VLLM_VERSION
@@ -46,6 +48,7 @@
4648
openai_serving_chat: OpenAIServingChat
4749
openai_serving_completion: OpenAIServingCompletion
4850
openai_serving_embedding: OpenAIServingEmbedding
51+
async_llm_engine: AsyncLLMEngine
4952

5053
logger = init_logger('vllm.entrypoints.openai.api_server')
5154

@@ -65,8 +68,15 @@ async def _force_log():
6568
_running_tasks.add(task)
6669
task.add_done_callback(_running_tasks.remove)
6770

71+
grpc_server = await start_grpc_server(async_llm_engine, args)
72+
6873
yield
6974

75+
logger.info("Gracefully stopping gRPC server")
76+
await grpc_server.stop(30) #TODO configurable grace
77+
await grpc_server.wait_for_termination()
78+
logger.info("gRPC server stopped")
79+
7080

7181
router = APIRouter()
7282

@@ -220,6 +230,16 @@ def run_server(args, llm_engine=None):
220230
global engine, engine_args
221231

222232
engine_args = AsyncEngineArgs.from_cli_args(args)
233+
234+
# Enforce pixel values as image input type for vision language models
235+
# when serving with API server
236+
if engine_args.image_input_type is not None and \
237+
engine_args.image_input_type.upper() != "PIXEL_VALUES":
238+
raise ValueError(
239+
f"Invalid image_input_type: {engine_args.image_input_type}. "
240+
"Only --image-input-type 'pixel_values' is supported for serving "
241+
"vision language models with the vLLM API server.")
242+
223243
engine = (llm_engine
224244
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
225245
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
@@ -241,6 +261,7 @@ def run_server(args, llm_engine=None):
241261
global openai_serving_chat
242262
global openai_serving_completion
243263
global openai_serving_embedding
264+
global async_llm_engine
244265

245266
openai_serving_chat = OpenAIServingChat(engine, model_config,
246267
served_model_names,
@@ -252,6 +273,11 @@ def run_server(args, llm_engine=None):
252273
args.prompt_adapters)
253274
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
254275
served_model_names)
276+
277+
# 🌶️🌶️🌶️ Sets the engine for the TGIS gRPC server.
278+
# Do not delete on merge conflicts!
279+
async_llm_engine = engine
280+
255281
app.root_path = args.root_path
256282

257283
logger.info("Available routes are:")
@@ -278,5 +304,8 @@ def run_server(args, llm_engine=None):
278304
parser = FlexibleArgumentParser(
279305
description="vLLM OpenAI-Compatible RESTful API server.")
280306
parser = make_arg_parser(parser)
307+
parser = add_tgis_args(parser)
281308
args = parser.parse_args()
309+
args = postprocess_tgis_args(args)
310+
282311
run_server(args)

vllm/scripts.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import signal
55
import sys
6+
from pathlib import Path
67
from typing import Optional
78

89
from openai import OpenAI
@@ -49,6 +50,19 @@ def interactive_cli(args: argparse.Namespace) -> None:
4950
chat(args.system_prompt, model_name, openai_client)
5051

5152

53+
def tgis_cli(args: argparse.Namespace) -> None:
54+
registrer_signal_handlers()
55+
56+
if args.command == "download-weights":
57+
download_weights(args.model_name, args.revision, args.token,
58+
args.extension, args.auto_convert)
59+
elif args.command == "convert-to-safetensors":
60+
convert_to_safetensors(args.model_name, args.revision)
61+
elif args.command == "convert-to-fast-tokenizer":
62+
convert_to_fast_tokenizer(args.model_name, args.revision,
63+
args.output_path)
64+
65+
5266
def complete(model_name: str, client: OpenAI) -> None:
5367
print("Please enter prompt to complete:")
5468
while True:
@@ -82,6 +96,151 @@ def chat(system_prompt: Optional[str], model_name: str,
8296
print(output)
8397

8498

99+
def download_weights(
100+
model_name: str,
101+
revision: Optional[str] = None,
102+
token: Optional[str] = None,
103+
extension: str = ".safetensors",
104+
auto_convert: bool = True,
105+
) -> None:
106+
from vllm.tgis_utils import hub
107+
108+
print(extension)
109+
meta_exts = [".json", ".py", ".model", ".md"]
110+
111+
extensions = extension.split(",")
112+
113+
if len(extensions) == 1 and extensions[0] not in meta_exts:
114+
extensions.extend(meta_exts)
115+
116+
files = hub.download_weights(model_name,
117+
extensions,
118+
revision=revision,
119+
auth_token=token)
120+
121+
if auto_convert and ".safetensors" in extensions:
122+
if not hub.local_weight_files(hub.get_model_path(model_name, revision),
123+
".safetensors"):
124+
if ".bin" not in extensions:
125+
print(".safetensors weights not found, \
126+
downloading pytorch weights to convert...")
127+
hub.download_weights(model_name,
128+
".bin",
129+
revision=revision,
130+
auth_token=token)
131+
132+
print(".safetensors weights not found, \
133+
converting from pytorch weights...")
134+
convert_to_safetensors(model_name, revision)
135+
elif not any(f.endswith(".safetensors") for f in files):
136+
print(".safetensors weights not found on hub, \
137+
but were found locally. Remove them first to re-convert")
138+
if auto_convert:
139+
convert_to_fast_tokenizer(model_name, revision)
140+
141+
142+
def convert_to_safetensors(
143+
model_name: str,
144+
revision: Optional[str] = None,
145+
):
146+
from vllm.tgis_utils import hub
147+
148+
# Get local pytorch file paths
149+
model_path = hub.get_model_path(model_name, revision)
150+
local_pt_files = hub.local_weight_files(model_path, ".bin")
151+
local_pt_index_files = hub.local_index_files(model_path, ".bin")
152+
if len(local_pt_index_files) > 1:
153+
print(
154+
f"Found more than one .bin.index.json file: {local_pt_index_files}"
155+
)
156+
return
157+
if not local_pt_files:
158+
print("No pytorch .bin files found to convert")
159+
return
160+
161+
local_pt_files = [Path(f) for f in local_pt_files]
162+
local_pt_index_file = local_pt_index_files[
163+
0] if local_pt_index_files else None
164+
165+
# Safetensors final filenames
166+
local_st_files = [
167+
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
168+
for p in local_pt_files
169+
]
170+
171+
if any(os.path.exists(p) for p in local_st_files):
172+
print("Existing .safetensors weights found, \
173+
remove them first to reconvert")
174+
return
175+
176+
try:
177+
import transformers
178+
179+
config = transformers.AutoConfig.from_pretrained(
180+
model_name,
181+
revision=revision,
182+
)
183+
architecture = config.architectures[0]
184+
185+
class_ = getattr(transformers, architecture)
186+
187+
# Name for this variable depends on transformers version
188+
discard_names = getattr(class_, "_tied_weights_keys", [])
189+
discard_names.extend(
190+
getattr(class_, "_keys_to_ignore_on_load_missing", []))
191+
192+
except Exception:
193+
discard_names = []
194+
195+
if local_pt_index_file:
196+
local_pt_index_file = Path(local_pt_index_file)
197+
st_prefix = local_pt_index_file.stem.removeprefix(
198+
"pytorch_").removesuffix(".bin.index")
199+
local_st_index_file = (local_pt_index_file.parent /
200+
f"{st_prefix}.safetensors.index.json")
201+
202+
if os.path.exists(local_st_index_file):
203+
print("Existing .safetensors.index.json file found, \
204+
remove it first to reconvert")
205+
return
206+
207+
hub.convert_index_file(local_pt_index_file, local_st_index_file,
208+
local_pt_files, local_st_files)
209+
210+
# Convert pytorch weights to safetensors
211+
hub.convert_files(local_pt_files, local_st_files, discard_names)
212+
213+
214+
def convert_to_fast_tokenizer(
215+
model_name: str,
216+
revision: Optional[str] = None,
217+
output_path: Optional[str] = None,
218+
):
219+
from vllm.tgis_utils import hub
220+
221+
# Check for existing "tokenizer.json"
222+
model_path = hub.get_model_path(model_name, revision)
223+
224+
if os.path.exists(os.path.join(model_path, "tokenizer.json")):
225+
print(f"Model {model_name} already has a fast tokenizer")
226+
return
227+
228+
if output_path is not None:
229+
if not os.path.isdir(output_path):
230+
print(f"Output path {output_path} must exist and be a directory")
231+
return
232+
else:
233+
output_path = model_path
234+
235+
import transformers
236+
237+
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,
238+
revision=revision)
239+
tokenizer.save_pretrained(output_path)
240+
241+
print(f"Saved tokenizer to {output_path}")
242+
243+
85244
def _add_query_options(
86245
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
87246
parser.add_argument(
@@ -142,6 +301,37 @@ def main():
142301
"used for models that support system prompts."))
143302
chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
144303

304+
download_weights_parser = subparsers.add_parser(
305+
"download-weights",
306+
help=("Download the weights of a given model"),
307+
usage="vllm download-weights <model_name> [options]")
308+
download_weights_parser.add_argument("model_name")
309+
download_weights_parser.add_argument("--revision")
310+
download_weights_parser.add_argument("--token")
311+
download_weights_parser.add_argument("--extension", default=".safetensors")
312+
download_weights_parser.add_argument("--auto_convert", default=True)
313+
download_weights_parser.set_defaults(dispatch_function=tgis_cli,
314+
command="download-weights")
315+
316+
convert_to_safetensors_parser = subparsers.add_parser(
317+
"convert-to-safetensors",
318+
help=("Convert model weights to safetensors"),
319+
usage="vllm convert-to-safetensors <model_name> [options]")
320+
convert_to_safetensors_parser.add_argument("model_name")
321+
convert_to_safetensors_parser.add_argument("--revision")
322+
convert_to_safetensors_parser.set_defaults(
323+
dispatch_function=tgis_cli, command="convert-to-safetensors")
324+
325+
convert_to_fast_tokenizer_parser = subparsers.add_parser(
326+
"convert-to-fast-tokenizer",
327+
help=("Convert to fast tokenizer"),
328+
usage="vllm convert-to-fast-tokenizer <model_name> [options]")
329+
convert_to_fast_tokenizer_parser.add_argument("model_name")
330+
convert_to_fast_tokenizer_parser.add_argument("--revision")
331+
convert_to_fast_tokenizer_parser.add_argument("--output_path")
332+
convert_to_fast_tokenizer_parser.set_defaults(
333+
dispatch_function=tgis_cli, command="convert-to-fast-tokenizer")
334+
145335
args = parser.parse_args()
146336
# One of the sub commands should be executed.
147337
if hasattr(args, "dispatch_function"):

0 commit comments

Comments
 (0)