Skip to content

Commit bd094b4

Browse files
youkaichaoAlvant
authored andcommitted
[misc] improve model support check in another process (vllm-project#9208)
Signed-off-by: Alvant <[email protected]>
1 parent 59a1710 commit bd094b4

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

docs/requirements-docs.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
44
myst-parser==2.0.0
55
sphinx-argparse==0.4.0
66
msgspec
7+
cloudpickle
78

89
# packages to install to build the documentation
910
pydantic >= 2.8

vllm/model_executor/models/registry.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import importlib
2-
import string
2+
import pickle
33
import subprocess
44
import sys
5-
import uuid
5+
import tempfile
66
from functools import lru_cache, partial
77
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
88

9+
import cloudpickle
910
import torch.nn as nn
1011

1112
from vllm.logger import init_logger
@@ -282,36 +283,28 @@ def _check_stateless(
282283

283284
raise
284285

285-
valid_name_characters = string.ascii_letters + string.digits + "._"
286-
if any(s not in valid_name_characters for s in mod_name):
287-
raise ValueError(f"Unsafe module name detected for {model_arch}")
288-
if any(s not in valid_name_characters for s in cls_name):
289-
raise ValueError(f"Unsafe class name detected for {model_arch}")
290-
if any(s not in valid_name_characters for s in func.__module__):
291-
raise ValueError(f"Unsafe module name detected for {func}")
292-
if any(s not in valid_name_characters for s in func.__name__):
293-
raise ValueError(f"Unsafe class name detected for {func}")
294-
295-
err_id = uuid.uuid4()
296-
297-
stmts = ";".join([
298-
f"from {mod_name} import {cls_name}",
299-
f"from {func.__module__} import {func.__name__}",
300-
f"assert {func.__name__}({cls_name}), '{err_id}'",
301-
])
302-
303-
result = subprocess.run([sys.executable, "-c", stmts],
304-
capture_output=True)
305-
306-
if result.returncode != 0:
307-
err_lines = [line.decode() for line in result.stderr.splitlines()]
308-
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
309-
err_str = "\n".join(err_lines)
310-
raise RuntimeError(
311-
"An unexpected error occurred while importing the model in "
312-
f"another process. Error log:\n{err_str}")
313-
314-
return result.returncode == 0
286+
with tempfile.NamedTemporaryFile() as output_file:
287+
# `cloudpickle` allows pickling lambda functions directly
288+
input_bytes = cloudpickle.dumps(
289+
(mod_name, cls_name, func, output_file.name))
290+
# cannot use `sys.executable __file__` here because the script
291+
# contains relative imports
292+
returned = subprocess.run(
293+
[sys.executable, "-m", "vllm.model_executor.models.registry"],
294+
input=input_bytes,
295+
capture_output=True)
296+
297+
# check if the subprocess is successful
298+
try:
299+
returned.check_returncode()
300+
except Exception as e:
301+
# wrap raised exception to provide more information
302+
raise RuntimeError(f"Error happened when testing "
303+
f"model support for{mod_name}.{cls_name}:\n"
304+
f"{returned.stderr.decode()}") from e
305+
with open(output_file.name, "rb") as f:
306+
result = pickle.load(f)
307+
return result
315308

316309
@staticmethod
317310
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
@@ -364,3 +357,13 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
364357
default=False)
365358

366359
return any(is_pp(arch) for arch in architectures)
360+
361+
362+
if __name__ == "__main__":
363+
(mod_name, cls_name, func,
364+
output_file) = pickle.loads(sys.stdin.buffer.read())
365+
mod = importlib.import_module(mod_name)
366+
klass = getattr(mod, cls_name)
367+
result = func(klass)
368+
with open(output_file, "wb") as f:
369+
f.write(pickle.dumps(result))

0 commit comments

Comments
 (0)