 import torch.utils.checkpoint
 import numpy as np

+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):

 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -169,6 +153,92 @@ def forward(

         return outputs

+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors, each containing all the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                Token length of each image in image_features.
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape)
+                    % (num_patch_height * num_patch_width * height * width)
+                    != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (image_feature, image_newline[None].to(image_feature)), dim=0
+                    )
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(
+            feature_lens, dtype=torch.long, device=image_features.device
+        )
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
@@ -303,61 +373,33 @@ def prepare_inputs_for_generation(
             )

             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
+            image_features, feature_lens = self.pack_image_features(
+                image_features,
+                image_sizes,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+                image_newline=self.image_newline,
             )

-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                        image_sizes[image_idx].tolist(),
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
-                    )
-                    image_feature = image_feature.permute(
-                        4, 0, 2, 1, 3
-                    ).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(
-                        image_feature, image_sizes[image_idx]
-                    )
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
-                        ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.cat(new_image_features, dim=0)
-            inputs_embeds = self._merge_input_ids_with_image_features(
-                inputs_embeds, image_features, input_ids
+            special_image_mask = (
+                input_ids == self.config.image_token_index
+            ).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
+            if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_features = image_features.shape[0]
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
             )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )
+
         # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
         # generation with cache
         elif past_key_values is not None: