Skip to content

Commit 042a496

Browse files
bbeckca and DarkLight1337
authored and committed
Migrate MllamaImagePixelInputs to TensorSchema (vllm-project#22020)
Signed-off-by: Benji Beck <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent a23605e commit 042a496

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

vllm/model_executor/models/mllama.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""PyTorch Mllama model."""
1818
import math
1919
from collections.abc import Iterable, Mapping, Sequence
20-
from typing import Literal, Optional, TypedDict, Union
20+
from typing import Annotated, Literal, Optional, Union
2121

2222
import numpy as np
2323
import torch
@@ -64,6 +64,7 @@
6464
EncDecMultiModalProcessor,
6565
PromptReplacement, PromptUpdate)
6666
from vllm.multimodal.profiling import BaseDummyInputsBuilder
67+
from vllm.utils.tensor_schema import TensorSchema, TensorShape
6768

6869
from .clip import CLIPMLP
6970
from .interfaces import SupportsMultiModal, SupportsV0Only
@@ -73,15 +74,30 @@
7374
logger = init_logger(__name__)
7475

7576

76-
class MllamaImagePixelInputs(TypedDict):
77-
type: Literal["pixel_values"]
78-
data: torch.Tensor
79-
"""Shape: """
80-
"""(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
81-
aspect_ratio_ids: torch.Tensor
82-
"""Shape: `(batch_size, max_num_image)`"""
83-
aspect_ratio_mask: torch.Tensor
84-
"""Shape: `(batch_size, max_num_image, max_num_tiles)`"""
77+
class MllamaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Mllama vision model, validated against the
    declared tensor shapes via ``TensorSchema``.

    Dimensions:
        - batch_size: Batch size
        - max_num_image: Max number of images
        - max_num_chunk: Max number of chunks
        - max_num_tiles: Max number of tiles per image
        - num_channel: Number of channels
        - height: Height
        - width: Width
    """

    # Discriminator tag; fixed to "pixel_values" for this input variant.
    type: Literal["pixel_values"] = "pixel_values"

    # Raw image pixels, padded out to the per-batch maxima declared above.
    data: Annotated[torch.Tensor,
                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
                                "num_channel", "height", "width")]

    # Per-image aspect-ratio bucket ids.
    aspect_ratio_ids: Annotated[torch.Tensor,
                                TensorShape("batch_size", "max_num_image")]

    # Mask selecting the valid tiles of each image (padding tiles excluded).
    aspect_ratio_mask: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_tiles")]
85101

86102

87103
# TODO: support LlamaImageEmbeddingInputs

0 commit comments

Comments
 (0)