Skip to content

Commit 37cc51b

Browse files
committed
Add platform guard
1 parent 7280766 commit 37cc51b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/models/test_registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch.cuda
55

66
from vllm.model_executor.models import _MODELS, ModelRegistry
7+
from vllm.platforms import current_platform
78

89
from ..utils import fork_new_process_for_each_test
910

@@ -23,7 +24,7 @@ def test_registry_imports(model_arch):
2324
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
2425
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
2526

26-
if init_cuda:
27+
if init_cuda and current_platform.is_cuda_alike():
2728
assert not torch.cuda.is_initialized()
2829

2930
ModelRegistry.resolve_model_cls(model_arch)
@@ -43,7 +44,7 @@ def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
4344
def test_registry_is_pp(model_arch, is_pp, init_cuda):
4445
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
4546

46-
if init_cuda:
47+
if init_cuda and current_platform.is_cuda_alike():
4748
assert not torch.cuda.is_initialized()
4849

4950
ModelRegistry.resolve_model_cls(model_arch)

0 commit comments

Comments
 (0)