|
17 | 17 | """PyTorch Mllama model."""
|
18 | 18 | import math
|
19 | 19 | from collections.abc import Iterable, Mapping, Sequence
|
20 |
| -from typing import Literal, Optional, TypedDict, Union |
| 20 | +from typing import Annotated, Literal, Optional, Union |
21 | 21 |
|
22 | 22 | import numpy as np
|
23 | 23 | import torch
|
|
64 | 64 | EncDecMultiModalProcessor,
|
65 | 65 | PromptReplacement, PromptUpdate)
|
66 | 66 | from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
| 67 | +from vllm.utils.tensor_schema import TensorSchema, TensorShape |
67 | 68 |
|
68 | 69 | from .clip import CLIPMLP
|
69 | 70 | from .interfaces import SupportsMultiModal, SupportsV0Only
|
|
73 | 74 | logger = init_logger(__name__)
|
74 | 75 |
|
75 | 76 |
|
76 |
| -class MllamaImagePixelInputs(TypedDict): |
77 |
| - type: Literal["pixel_values"] |
78 |
| - data: torch.Tensor |
79 |
| - """Shape: """ |
80 |
| - """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" |
81 |
| - aspect_ratio_ids: torch.Tensor |
82 |
| - """Shape: `(batch_size, max_num_image)`""" |
83 |
| - aspect_ratio_mask: torch.Tensor |
84 |
| - """Shape: `(batch_size, max_num_image, max_num_tiles)`""" |
class MllamaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Mllama vision model, with per-field shape
    validation provided by ``TensorSchema``.

    Dimensions:
        - batch_size: Batch size
        - max_num_image: Max number of images
        - max_num_chunk: Max number of chunks
        - max_num_tiles: Max number of tiles per image
        - num_channel: Number of channels
        - height: Height
        - width: Width
    """

    # Tag discriminating this input variant (cf. the LlamaImageEmbeddingInputs
    # TODO below); defaults to "pixel_values" so callers need not set it.
    type: Literal["pixel_values"] = "pixel_values"

    # Raw pixel values, shape:
    # (batch_size, max_num_image, max_num_chunk, num_channel, height, width)
    data: Annotated[torch.Tensor,
                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
                                "num_channel", "height", "width")]

    # Aspect-ratio ID per image, shape: (batch_size, max_num_image)
    aspect_ratio_ids: Annotated[torch.Tensor,
                                TensorShape("batch_size", "max_num_image")]

    # Per-tile mask, shape: (batch_size, max_num_image, max_num_tiles).
    # NOTE(review): presumably marks which tiles of each image are valid
    # (images may use fewer than max_num_tiles) — confirm against the
    # processor that builds it.
    aspect_ratio_mask: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_tiles")]
85 | 101 |
|
86 | 102 |
|
87 | 103 | # TODO: support LlamaImageEmbeddingInputs
|
|
0 commit comments