|
1 | 1 | import importlib
|
2 |
| -import string |
| 2 | +import pickle |
3 | 3 | import subprocess
|
4 | 4 | import sys
|
5 |
| -import uuid |
| 5 | +import tempfile |
6 | 6 | from functools import lru_cache, partial
|
7 | 7 | from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
8 | 8 |
|
| 9 | +import cloudpickle |
9 | 10 | import torch.nn as nn
|
10 | 11 |
|
11 | 12 | from vllm.logger import init_logger
|
@@ -282,36 +283,28 @@ def _check_stateless(
|
282 | 283 |
|
283 | 284 | raise
|
284 | 285 |
|
285 |
| - valid_name_characters = string.ascii_letters + string.digits + "._" |
286 |
| - if any(s not in valid_name_characters for s in mod_name): |
287 |
| - raise ValueError(f"Unsafe module name detected for {model_arch}") |
288 |
| - if any(s not in valid_name_characters for s in cls_name): |
289 |
| - raise ValueError(f"Unsafe class name detected for {model_arch}") |
290 |
| - if any(s not in valid_name_characters for s in func.__module__): |
291 |
| - raise ValueError(f"Unsafe module name detected for {func}") |
292 |
| - if any(s not in valid_name_characters for s in func.__name__): |
293 |
| - raise ValueError(f"Unsafe class name detected for {func}") |
294 |
| - |
295 |
| - err_id = uuid.uuid4() |
296 |
| - |
297 |
| - stmts = ";".join([ |
298 |
| - f"from {mod_name} import {cls_name}", |
299 |
| - f"from {func.__module__} import {func.__name__}", |
300 |
| - f"assert {func.__name__}({cls_name}), '{err_id}'", |
301 |
| - ]) |
302 |
| - |
303 |
| - result = subprocess.run([sys.executable, "-c", stmts], |
304 |
| - capture_output=True) |
305 |
| - |
306 |
| - if result.returncode != 0: |
307 |
| - err_lines = [line.decode() for line in result.stderr.splitlines()] |
308 |
| - if err_lines and err_lines[-1] != f"AssertionError: {err_id}": |
309 |
| - err_str = "\n".join(err_lines) |
310 |
| - raise RuntimeError( |
311 |
| - "An unexpected error occurred while importing the model in " |
312 |
| - f"another process. Error log:\n{err_str}") |
313 |
| - |
314 |
| - return result.returncode == 0 |
| 286 | + with tempfile.NamedTemporaryFile() as output_file: |
| 287 | + # `cloudpickle` allows pickling lambda functions directly |
| 288 | + input_bytes = cloudpickle.dumps( |
| 289 | + (mod_name, cls_name, func, output_file.name)) |
| 290 | + # cannot use `sys.executable __file__` here because the script |
| 291 | + # contains relative imports |
| 292 | + returned = subprocess.run( |
| 293 | + [sys.executable, "-m", "vllm.model_executor.models.registry"], |
| 294 | + input=input_bytes, |
| 295 | + capture_output=True) |
| 296 | + |
| 297 | + # check if the subprocess is successful |
| 298 | + try: |
| 299 | + returned.check_returncode() |
| 300 | + except Exception as e: |
| 301 | + # wrap raised exception to provide more information |
| 302 | + raise RuntimeError(f"Error happened when testing " |
| 303 | + f"model support for{mod_name}.{cls_name}:\n" |
| 304 | + f"{returned.stderr.decode()}") from e |
| 305 | + with open(output_file.name, "rb") as f: |
| 306 | + result = pickle.load(f) |
| 307 | + return result |
315 | 308 |
|
316 | 309 | @staticmethod
|
317 | 310 | def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
|
@@ -364,3 +357,13 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
|
364 | 357 | default=False)
|
365 | 358 |
|
366 | 359 | return any(is_pp(arch) for arch in architectures)
|
| 360 | + |
| 361 | + |
| 362 | +if __name__ == "__main__": |
| 363 | + (mod_name, cls_name, func, |
| 364 | + output_file) = pickle.loads(sys.stdin.buffer.read()) |
| 365 | + mod = importlib.import_module(mod_name) |
| 366 | + klass = getattr(mod, cls_name) |
| 367 | + result = func(klass) |
| 368 | + with open(output_file, "wb") as f: |
| 369 | + f.write(pickle.dumps(result)) |
0 commit comments