Skip to content

Update deprecated type hinting in vllm/profiler #18057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ exclude = [
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
Expand Down
38 changes: 19 additions & 19 deletions vllm/profiler/layerwise_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
from typing import Any, Callable, Optional, TypeAlias, Union

import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
Expand All @@ -20,7 +20,7 @@
class _ModuleTreeNode:
event: _ProfilerEvent
parent: Optional['_ModuleTreeNode'] = None
children: List['_ModuleTreeNode'] = field(default_factory=list)
children: list['_ModuleTreeNode'] = field(default_factory=list)
trace: str = ""

@property
Expand Down Expand Up @@ -60,19 +60,19 @@ class ModelStatsEntry:
@dataclass
class _StatsTreeNode:
entry: StatsEntry
children: List[StatsEntry]
children: list[StatsEntry]
parent: Optional[StatsEntry]


@dataclass
class LayerwiseProfileResults(profile):
_kineto_results: _ProfilerResult
_kineto_event_correlation_map: Dict[int,
List[_KinetoEvent]] = field(init=False)
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
_module_tree: List[_ModuleTreeNode] = field(init=False)
_model_stats_tree: List[_StatsTreeNode] = field(init=False)
_summary_stats_tree: List[_StatsTreeNode] = field(init=False)
_kineto_event_correlation_map: dict[int,
list[_KinetoEvent]] = field(init=False)
_event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
_module_tree: list[_ModuleTreeNode] = field(init=False)
_model_stats_tree: list[_StatsTreeNode] = field(init=False)
_summary_stats_tree: list[_StatsTreeNode] = field(init=False)

# profile metadata
num_running_seqs: Optional[int] = None
Expand All @@ -82,7 +82,7 @@ def __post_init__(self):
self._build_module_tree()
self._build_stats_trees()

def print_model_table(self, column_widths: Dict[str, int] = None):
def print_model_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
Expand All @@ -100,7 +100,7 @@ def print_model_table(self, column_widths: Dict[str, int] = None):
filtered_model_table,
indent_style=lambda indent: "|" + "-" * indent + " "))

def print_summary_table(self, column_widths: Dict[str, int] = None):
def print_summary_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
Expand Down Expand Up @@ -142,7 +142,7 @@ def convert_stats_to_dict(self) -> dict[str, Any]:
}

@staticmethod
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
StatsEntry]],
indent_style: Union[Callable[[int],
str],
Expand Down Expand Up @@ -229,7 +229,7 @@ def _total_cuda_time(self):
[self._cumulative_cuda_time(root) for root in self._module_tree])

def _build_stats_trees(self):
summary_dict: Dict[str, _StatsTreeNode] = {}
summary_dict: dict[str, _StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time()

def pct_cuda_time(cuda_time_us):
Expand All @@ -238,7 +238,7 @@ def pct_cuda_time(cuda_time_us):
def build_summary_stats_tree_df(
node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None,
summary_trace: Tuple[str] = ()):
summary_trace: tuple[str] = ()):

if event_has_module(node.event):
name = event_module_repr(node.event)
Expand Down Expand Up @@ -313,8 +313,8 @@ def build_model_stats_tree_df(node: _ModuleTreeNode,
self._model_stats_tree.append(build_model_stats_tree_df(root))

def _flatten_stats_tree(
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
entries: List[Tuple[int, StatsEntry]] = []
self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
entries: list[tuple[int, StatsEntry]] = []

def df_traversal(node: _StatsTreeNode, depth=0):
entries.append((depth, node.entry))
Expand All @@ -327,10 +327,10 @@ def df_traversal(node: _StatsTreeNode, depth=0):
return entries

def _convert_stats_tree_to_dict(self,
tree: List[_StatsTreeNode]) -> List[Dict]:
root_dicts: List[Dict] = []
tree: list[_StatsTreeNode]) -> list[dict]:
root_dicts: list[dict] = []

def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
curr_json_list.append({
"entry": asdict(node.entry),
"children": []
Expand Down
8 changes: 4 additions & 4 deletions vllm/profiler/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from typing import Callable, Dict, List, Type, Union
from typing import Callable, Union

from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata

Expand Down Expand Up @@ -30,14 +30,14 @@ def trim_string_back(string, width):

class TablePrinter:

def __init__(self, row_cls: Type[dataclasses.dataclass],
column_widths: Dict[str, int]):
def __init__(self, row_cls: type[dataclasses.dataclass],
column_widths: dict[str, int]):
self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames)

def print_table(self, rows: List[dataclasses.dataclass]):
def print_table(self, rows: list[dataclasses.dataclass]):
self._print_header()
self._print_line()
for row in rows:
Expand Down