Skip to content

Commit 7280766

Browse files
committed
Test CUDA initialization
1 parent dcc2a49 commit 7280766

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

tests/models/test_registry.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,54 @@
1+
import warnings
2+
13
import pytest
4+
import torch.cuda
25

36
from vllm.model_executor.models import _MODELS, ModelRegistry
47

8+
from ..utils import fork_new_process_for_each_test
9+
510

6-
@pytest.mark.parametrize("model_cls", _MODELS)
7-
def test_registry_imports(model_cls):
11+
@pytest.mark.parametrize("model_arch", _MODELS)
12+
def test_registry_imports(model_arch):
813
# Ensure all model classes can be imported successfully
9-
ModelRegistry.resolve_model_cls([model_cls])
14+
ModelRegistry.resolve_model_cls(model_arch)
1015

1116

12-
@pytest.mark.parametrize("model_cls,is_mm", [
13-
("LlamaForCausalLM", False),
14-
("MllamaForConditionalGeneration", True),
17+
@fork_new_process_for_each_test
18+
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
19+
("LlamaForCausalLM", False, False),
20+
("MllamaForConditionalGeneration", True, False),
21+
("LlavaForConditionalGeneration", True, True),
1522
])
16-
def test_registry_is_multimodal(model_cls, is_mm):
17-
assert ModelRegistry.is_multimodal_model(model_cls) is is_mm
23+
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
24+
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
25+
26+
if init_cuda:
27+
assert not torch.cuda.is_initialized()
1828

29+
ModelRegistry.resolve_model_cls(model_arch)
30+
if not torch.cuda.is_initialized():
31+
warnings.warn(
32+
"This model no longer initializes CUDA on import. "
33+
"Please test using a different model.",
34+
stacklevel=2)
1935

20-
@pytest.mark.parametrize("model_cls,is_pp", [
21-
("MLPSpeculatorPreTrainedModel", False),
22-
("DeepseekV2ForCausalLM", True),
36+
37+
@fork_new_process_for_each_test
38+
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
39+
("MLPSpeculatorPreTrainedModel", False, False),
40+
("DeepseekV2ForCausalLM", True, False),
41+
("Qwen2VLForConditionalGeneration", True, True),
2342
])
24-
def test_registry_is_pp(model_cls, is_pp):
25-
assert ModelRegistry.is_pp_supported_model(model_cls) is is_pp
43+
def test_registry_is_pp(model_arch, is_pp, init_cuda):
44+
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
45+
46+
if init_cuda:
47+
assert not torch.cuda.is_initialized()
48+
49+
ModelRegistry.resolve_model_cls(model_arch)
50+
if not torch.cuda.is_initialized():
51+
warnings.warn(
52+
"This model no longer initializes CUDA on import. "
53+
"Please test using a different model.",
54+
stacklevel=2)

vllm/model_executor/models/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,12 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
221221

222222
@staticmethod
223223
def resolve_model_cls(
224-
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
224+
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
225+
if isinstance(architectures, str):
226+
architectures = [architectures]
227+
if not architectures:
228+
logger.warning("No model architectures are specified")
229+
225230
for arch in architectures:
226231
model_cls = ModelRegistry._try_load_model_cls(arch)
227232
if model_cls is not None:

0 commit comments

Comments
 (0)