
Commit 19324d6

Update deprecated type hinting in vllm/compilation (#18072)
Signed-off-by: Harry Mellor <[email protected]>
1 parent fc407a1 commit 19324d6

13 files changed: +70 -69 lines changed
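
Every file in this commit follows the same mechanical pattern: container annotations move from the deprecated typing aliases (Dict, List, Set, Tuple) to the builtin generics permitted since Python 3.9 by PEP 585, and abstract collection types such as Sequence and Iterable are imported from collections.abc rather than typing. A minimal sketch of the pattern is shown below; the function and its names are illustrative only and are not taken from the repository.

# Before: aliases from typing, deprecated since Python 3.9 (PEP 585)
from typing import Dict, List, Optional, Sequence

def merge_counts(counts: Dict[str, int],
                 extra: Optional[List[int]] = None) -> Sequence[int]:
    # return the counts followed by any extra values
    return list(counts.values()) + (extra or [])

# After: builtin generics plus collections.abc, the style this commit adopts
# (Optional is kept as-is, matching the diffs below)
from collections.abc import Sequence
from typing import Optional

def merge_counts(counts: dict[str, int],
                 extra: Optional[list[int]] = None) -> Sequence[int]:
    return list(counts.values()) + (extra or [])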

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
-"vllm/compilation/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
 "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]

vllm/compilation/backends.py

Lines changed: 17 additions & 16 deletions
@@ -5,8 +5,9 @@
 import os
 import pprint
 import time
+from collections.abc import Sequence
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -56,7 +57,7 @@ class CompilerManager:
     """
 
     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -90,7 +91,7 @@ def save_to_file(self):
 
     def load(self,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Optional[Callable]:
         if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
@@ -186,7 +187,7 @@ class SplitItem:
 
 
 def split_graph(graph: fx.GraphModule,
-                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
     node_to_subgraph_id = {}
@@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 compile_submod_names: list[str], vllm_config: VllmConfig,
                  graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
@@ -274,8 +275,8 @@ def run(self, *args):
         return super().run(*fake_args)
 
     def call_module(self, target: torch.fx.node.Target,
-                    args: Tuple[torch.fx.node.Argument,
-                                ...], kwargs: Dict[str, Any]) -> Any:
+                    args: tuple[torch.fx.node.Argument,
+                                ...], kwargs: dict[str, Any]) -> Any:
         assert isinstance(target, str)
         output = super().call_module(target, args, kwargs)
 
@@ -326,12 +327,12 @@ class VllmBackend:
     graph: fx.GraphModule
     # the stiching graph module for all the piecewise graphs
     split_gm: fx.GraphModule
-    piecewise_graphs: List[SplitItem]
+    piecewise_graphs: list[SplitItem]
     returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
-    sym_tensor_indices: List[int]
-    input_buffers: List[torch.Tensor]
+    sym_tensor_indices: list[int]
+    input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
 
     def __init__(
@@ -573,14 +574,14 @@ class ConcreteSizeEntry:
 
     # for cudagraph debugging, track the input addresses
     # during capture, and check if they are the same during replay
-    input_addresses: Optional[List[int]] = None
+    input_addresses: Optional[list[int]] = None
 
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: List[int],
+                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                  compiled_graph_for_general_shape: Callable,
                  vllm_backend: VllmBackend):
         """
@@ -608,9 +609,9 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.is_last_graph = (
             piecewise_compile_index == total_piecewise_compiles - 1)
 
-        self.compile_sizes: Set[int] = set(
+        self.compile_sizes: set[int] = set(
             self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: Set[int] = set(
+        self.cudagraph_capture_sizes: set[int] = set(
             self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()
 
@@ -624,11 +625,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
 
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
 
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
+        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,

vllm/compilation/compiler_interface.py

Lines changed: 19 additions & 19 deletions
@@ -4,7 +4,7 @@
 import hashlib
 import os
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -48,11 +48,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
         with a runtime shape. If the `runtime_shape` is None, it means
@@ -82,7 +82,7 @@ def compile(
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         """
@@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
     """
 
     def __init__(self) -> None:
-        self.guards: List[Any] = []
+        self.guards: list[Any] = []
 
     def evaluate_guards_expression(self, *args, **kwargs):
         return True
@@ -132,8 +132,8 @@ def produce_guards_expression(self, *args, **kwargs):
         return ""
 
 
-def get_inductor_factors() -> List[Any]:
-    factors: List[Any] = []
+def get_inductor_factors() -> list[Any]:
+    factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
     system_factors = CacheBase.get_system()
@@ -169,11 +169,11 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -201,7 +201,7 @@ def compile(
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -256,11 +256,11 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -420,7 +420,7 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None

vllm/compilation/decorators.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
-from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
 import torch
@@ -25,7 +25,7 @@
 @overload
 def support_torch_compile(
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
 ) -> Callable[[_T], _T]:
     ...
 
@@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
 def support_torch_compile(
     cls: Optional[_T] = None,
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -131,7 +131,7 @@ def cls_decorator_helper(cls: _T) -> _T:
 
 def _support_torch_compile(
     cls: _T,
-    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.

vllm/compilation/fix_functionalization.py

Lines changed: 8 additions & 7 deletions
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -27,7 +28,7 @@ def __call__(self, graph: torch.fx.Graph):
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
-        self.nodes_to_remove: List[torch.fx.Node] = []
+        self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
             if not is_func(node, auto_functionalized):
@@ -117,8 +118,8 @@ def _remove(self, node_or_nodes: Union[torch.fx.Node,
     def defunctionalize(self,
                         graph: torch.fx.Graph,
                         node: torch.fx.Node,
-                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
-                        args: Optional[Tuple[Union[torch.fx.Node, str],
+                        mutated_args: dict[int, Union[torch.fx.Node, str]],
+                        args: Optional[tuple[Union[torch.fx.Node, str],
                                              ...]] = None):
         """
         De-functionalize a node by replacing it with a call to the original.
@@ -130,7 +131,7 @@ def defunctionalize(self,
         self._remove(node)
 
     def replace_users_with_mutated_args(self, node: torch.fx.Node,
-                                        mutated_args: Dict[int,
+                                        mutated_args: dict[int,
                                                            Union[torch.fx.Node,
                                                                  str]]):
         """
@@ -146,7 +147,7 @@ def replace_users_with_mutated_args(self, node: torch.fx.Node,
             user.replace_all_uses_with(arg)
             self._remove(user)
 
-    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
         """
         Returns the operator.getitem users of the auto-functionalized node,
         indexed by the index they are getting.
@@ -161,7 +162,7 @@ def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
     def insert_defunctionalized(self,
                                 graph: torch.fx.Graph,
                                 node: torch.fx.Node,
-                                args: Optional[Tuple[Union[torch.fx.Node, str],
+                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                      ...]] = None):
         """
         Insert a new defunctionalized node into the graph before node.

vllm/compilation/fusion.py

Lines changed: 5 additions & 5 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -57,7 +57,7 @@ def __str__(self):
 kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
 kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
 
-QUANT_OPS: Dict[QuantKey, OpOverload] = {
+QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
     kFp8DynamicTensorSym:
     torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
@@ -80,7 +80,7 @@ def __str__(self):
                 f"{'' if self.fused_add else 'out'} residual)")
 
 
-FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
     torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
@@ -101,7 +101,7 @@ def __init__(self, match: pm.Match, quant_op, fused_op):
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
-    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
                                                                       int]],
                           **kwargs):
         """
@@ -548,7 +548,7 @@ def __init__(self, config: VllmConfig):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[MultiOutputMatch] = []
+        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 

vllm/compilation/fx_utils.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
