@@ -42,7 +42,7 @@
                            SequenceData)
 
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
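
Note: flatten_bn is imported from the shared model utilities and is used in the last hunk below, where the model now accepts per-request lists of image tensors. The snippet that follows is only a rough, assumption-based sketch of what such a helper does (collapsing the batch dimension, optionally concatenating a list of per-image tensors); it is not vLLM's actual implementation, and the name flatten_bn_sketch is made up for this note.

# Rough sketch only, not vLLM's implementation.
from typing import List, Union

import torch


def flatten_bn_sketch(x: Union[torch.Tensor, List[torch.Tensor]],
                      concat: bool = False):
    if isinstance(x, torch.Tensor):
        x = list(x.unbind(0))  # split the batch dim into a list of tensors
    # With concat=True, join all per-image tensors along the first dim.
    return torch.cat(x) if concat else x
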
@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
             model_config.model)
 
         model_image_input = _fuyu_image_preprocess(image_processor, image_data)
-        image_patches = torch.stack([
+        image_patches = torch.cat([
             image_patch[0]
             for image_patch in model_image_input["image_patches"]
         ])
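
Note on the torch.stack to torch.cat change: stack creates a new leading dimension and requires every patch tensor to have an identical shape, while cat joins along an existing dimension, so images that yield different numbers of patches can still be combined. A standalone illustration, not part of the patch:

import torch

a = torch.randn(3, 4)  # e.g. 3 flattened patches of size 4
b = torch.randn(5, 4)  # a second image with 5 patches

print(torch.cat([a, b]).shape)  # torch.Size([8, 4]) - patch counts may differ
# torch.stack([a, b]) would raise here, since stack needs equal shapes and
# would add a new leading dimension: (2, 3, 4) for two (3, 4) tensors.
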
@@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
         ])
 
     # image has been processed with prompt in input processor
-    return MultiModalInputs({"image_patches": data})
+    return MultiModalInputs({"pixel_values": data})
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
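
Note: the mapper's dict key is renamed so it matches the keyword argument the model pops in _parse_and_validate_image_input (next hunk). A toy illustration of that naming contract, assuming the mapper's output is eventually forwarded to the model as keyword arguments; the toy_* names are hypothetical:

def toy_mapper(data):
    # key must match what the model expects in **kwargs
    return {"pixel_values": data}


def toy_parse_and_validate(**kwargs):
    return kwargs.pop("pixel_values", None)


assert toy_parse_and_validate(**toy_mapper([1, 2, 3])) == [1, 2, 3]
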
@@ -242,23 +242,42 @@ def __init__(self,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
 
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+
+        h = w = self.config.patch_size
+        num_channels = self.config.num_channels
+        expected_dims = num_channels * h * w
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = d.size(-1)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data.to(self.vision_embed_tokens.weight.dtype)
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
-        image_patches = kwargs.pop("image_patches", None)
+        pixel_values = kwargs.pop("pixel_values", None)
 
-        if isinstance(image_patches, torch.Tensor):
-            # Remove the N dimension until multiple images are supported.
-            image_patches = image_patches.squeeze(1)
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image patches. "
+                                 f"Got type: {type(pixel_values)}")
+
+            return FuyuImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
+            )
 
-        expected_feature_size = self.image_feature_size
-        if image_patches.size(-1) != expected_feature_size:
-            raise ValueError(
-                f"Expected image patches to have the last dimension of "
-                f"{expected_feature_size}, got {image_patches.size(-1)}")
-        image_patches = image_patches.to(
-            self.vision_embed_tokens.weight.dtype)
-        return FuyuImagePixelInputs(type="pixel_values",
-                                    data=image_patches)
         return None
 
     def _process_image_input(
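
Note: the new _validate_pixel_values checks that each flattened patch carries num_channels * patch_size * patch_size values before casting to the vision embedding dtype. A standalone sanity check of that arithmetic, assuming fuyu-8b's commonly cited config values of 30x30 patches and 3 channels (so 2700 values per patch):

import torch

patch_size, num_channels = 30, 3  # assumed fuyu-8b config values
expected_dims = num_channels * patch_size * patch_size
assert expected_dims == 2700

patches = torch.randn(7, expected_dims)   # 7 flattened patches
assert patches.size(-1) == expected_dims  # satisfies the shape rule above
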