Skip to content

Commit 7678fcd

Browse files
authored
Fix the torch version parsing logic (vllm-project#15857)
1 parent 8661c02 commit 7678fcd

File tree

4 files changed

+26
-11
lines changed

4 files changed

+26
-11
lines changed

vllm/compilation/compiler_interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import contextlib
33
import copy
44
import hashlib
5-
import importlib.metadata
65
import os
76
from contextlib import ExitStack
87
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -11,9 +10,9 @@
1110
import torch
1211
import torch._inductor.compile_fx
1312
import torch.fx as fx
14-
from packaging.version import Version
1513

1614
from vllm.config import VllmConfig
15+
from vllm.utils import is_torch_equal_or_newer
1716

1817

1918
class CompilerInterface:
@@ -379,7 +378,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
379378
manually setting up internal contexts. But we also rely on non-public
380379
APIs which might not provide these guarantees.
381380
"""
382-
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
381+
if is_torch_equal_or_newer("2.6"):
383382
import torch._dynamo.utils
384383
return torch._dynamo.utils.get_metrics_context()
385384
else:

vllm/compilation/inductor_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import hashlib
4-
import importlib.metadata
54
import inspect
65
import json
76
import types
87
from typing import Any, Callable, Dict, Optional, Union
98

109
import torch
11-
from packaging.version import Version
1210
from torch import fx
1311

14-
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
12+
from vllm.utils import is_torch_equal_or_newer
13+
14+
if is_torch_equal_or_newer("2.6"):
1515
from torch._inductor.custom_graph_pass import CustomGraphPass
1616
else:
1717
# CustomGraphPass is not present in 2.5 or lower, import our version

vllm/config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import copy
55
import enum
66
import hashlib
7-
import importlib.metadata
87
import json
98
import sys
109
import warnings
@@ -18,7 +17,6 @@
1817
Optional, Protocol, Union)
1918

2019
import torch
21-
from packaging.version import Version
2220
from pydantic import BaseModel, Field, PrivateAttr
2321
from torch.distributed import ProcessGroup, ReduceOp
2422
from transformers import PretrainedConfig
@@ -40,8 +38,8 @@
4038
from vllm.transformers_utils.s3_utils import S3Model
4139
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
4240
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
43-
get_cpu_memory, get_open_port, random_uuid,
44-
resolve_obj_by_qualname)
41+
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
42+
random_uuid, resolve_obj_by_qualname)
4543

4644
if TYPE_CHECKING:
4745
from ray.util.placement_group import PlacementGroup
@@ -3285,7 +3283,7 @@ def model_post_init(self, __context: Any) -> None:
32853283
# and it is not yet a priority. RFC here:
32863284
# https://github.com/vllm-project/vllm/issues/14703
32873285

3288-
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
3286+
if is_torch_equal_or_newer("2.6"):
32893287
KEY = 'enable_auto_functionalized_v2'
32903288
if KEY not in self.inductor_compile_config:
32913289
self.inductor_compile_config[KEY] = False

vllm/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import yaml
5454
import zmq
5555
import zmq.asyncio
56+
from packaging import version
5657
from packaging.version import Version
5758
from torch.library import Library
5859
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
25802581
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
25812582
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
25822583
byteorder="big")
2584+
2585+
2586+
def is_torch_equal_or_newer(target: str) -> bool:
    """Check whether the installed torch version is >= ``target``.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        True if the installed torch version is equal to or newer than
        ``target``, False otherwise.
    """
    try:
        # Prefer torch's own reported version over package metadata.
        # NOTE: under PEP 440 ordering, pre-release builds (e.g.
        # "2.6.0.dev20250101") compare *below* the final "2.6.0" release.
        torch_version = version.parse(str(torch.__version__))
        return torch_version >= version.parse(target)
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc
        # gen (where ``torch.__version__`` may not be a parseable string).
        return Version(importlib.metadata.version('torch')) >= Version(target)

0 commit comments

Comments
 (0)