Commit f5f14dc

Gaudi: Fix llava-next and mllama crash issue (#3127)
Signed-off-by: yuanwu <[email protected]>
1 parent 54d1546 commit f5f14dc

File tree: 2 files changed, +116 −72 lines changed


backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py

Lines changed: 111 additions & 69 deletions
@@ -20,6 +20,7 @@
 import torch.utils.checkpoint
 import numpy as np
 
+from loguru import logger
 from transformers.models.llava_next.modeling_llava_next import (
     unpad_image,
 )
@@ -92,23 +93,6 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
 
 class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
 
-    def _merge_input_ids_with_image_features(
-        self,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-        input_ids: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        mask = input_ids == self.config.image_token_index
-        # Let's pray we have enabled enough slots !
-        try:
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        except Exception as e:
-            raise RuntimeError(
-                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
-            )
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -169,6 +153,92 @@ def forward(
 
         return outputs
 
+    # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+    def pack_image_features(
+        self,
+        image_features,
+        image_sizes,
+        vision_feature_select_strategy,
+        image_newline=None,
+    ):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensor, each contains all the visual feature of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each images (H, W).
+            vision_feature_select_strategy (`str`)
+                The feature selection strategy used to select the vision feature from the vision backbone.
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                token length of each image in image_features
+        """
+        new_image_features = []
+        feature_lens = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape)
+                    % (num_patch_height * num_patch_width * height * width)
+                    != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (image_feature, image_newline[None].to(image_feature)), dim=0
+                    )
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(
+            feature_lens, dtype=torch.long, device=image_features.device
+        )
+        return image_features, feature_lens
+
     # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
     def get_image_features(
         self,
@@ -303,61 +373,33 @@ def prepare_inputs_for_generation(
                 )
 
                 # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = (
-                    self.config.vision_config.image_size
-                    // self.config.vision_config.patch_size
+                image_features, feature_lens = self.pack_image_features(
+                    image_features,
+                    image_sizes,
+                    vision_feature_select_strategy=vision_feature_select_strategy,
+                    image_newline=self.image_newline,
                 )
 
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-
-                        if height * width != base_image_feature.shape[0]:
-                            raise ValueError(
-                                "The number of patches is not consistent with the image size."
-                            )
-
-                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                            image_sizes[image_idx].tolist(),
-                            self.config.image_grid_pinpoints,
-                            self.config.vision_config.image_size,
-                        )
-
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
-                        image_feature = image_feature.permute(
-                            4, 0, 2, 1, 3
-                        ).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
-                        )
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1
-                                ),
-                            ),
-                            dim=-1,
-                        )
-                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat(
-                            (base_image_feature, image_feature), dim=0
-                        )
-                    else:
-                        image_feature = image_feature[0]
-                        image_feature = torch.cat(
-                            (image_feature, self.image_newline[None]), dim=0
-                        )
-                    new_image_features.append(image_feature)
-                image_features = torch.cat(new_image_features, dim=0)
-                inputs_embeds = self._merge_input_ids_with_image_features(
-                    inputs_embeds, image_features, input_ids
+                special_image_mask = (
+                    input_ids == self.config.image_token_index
+                ).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                    inputs_embeds.device
+                )
+                if inputs_embeds[special_image_mask].numel() != image_features.numel():
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                    n_image_features = image_features.shape[0]
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                image_features = image_features.to(
+                    inputs_embeds.device, inputs_embeds.dtype
+                )
+                inputs_embeds = inputs_embeds.masked_scatter(
+                    special_image_mask, image_features
                )
+
                 # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                 # generation with cache
             elif past_key_values is not None:
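
The key behavioral change in this file is the switch from an in-place boolean-index write (the removed `inputs_embeds[mask] = image_features.view(...)`) to `masked_scatter`. The following is a minimal, self-contained sketch of that merge pattern; it is not part of the commit, and `IMAGE_TOKEN_ID`, `EMBED_DIM`, and the tensor values are illustrative placeholders.

import torch

IMAGE_TOKEN_ID = 32000  # placeholder for self.config.image_token_index
EMBED_DIM = 4

# One sequence with image-token slots at positions 1 and 2.
input_ids = torch.tensor([[101, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 102]])
inputs_embeds = torch.zeros(1, 4, EMBED_DIM)
image_features = torch.arange(2 * EMBED_DIM, dtype=torch.float32).view(2, EMBED_DIM)

# Expand a per-token mask across the embedding dimension, as the new code does.
special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

# The commit adds this consistency check before scattering.
if inputs_embeds[special_image_mask].numel() != image_features.numel():
    raise ValueError("Image features and image tokens do not match")

# masked_scatter returns a new tensor rather than writing through a boolean
# index in place, which is what _merge_input_ids_with_image_features did.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.])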

backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py

Lines changed: 5 additions & 3 deletions
@@ -428,6 +428,9 @@ def batch_tokenized_inputs(
             else:
                 images.append(curr_image)
 
+        if is_warmup is True:
+            images += [images[0]] * (len(texts) - len(images))
+
         missing_inputs = 0
         dummy_images = None
         if is_warmup is False:
@@ -1464,7 +1467,6 @@ def warmup(
         batch = self.batch_from_pb(request.batch, is_warmup=True)
         max_input_tokens = request.max_input_tokens
         max_prefill_batch_size = batch.input_ids.shape[0]
-
         try:
             # max prefill batch size warmup
             _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
@@ -1548,7 +1550,7 @@ def warmup(
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 max_prefill_batch_size,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
@@ -1568,7 +1570,7 @@ def warmup(
                 request,
                 PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
                 2,
-                is_warmup=False,
+                is_warmup=True,
             )
             _, prefill_batch, _ = self.generate_token(
                 [batch], is_warmup=True
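
Besides flipping the two warmup-batch construction calls to `is_warmup=True`, this file pads the decoded image list during warmup so its length matches the number of texts. A tiny sketch of that padding logic follows; it is not TGI code, and the `texts`/`images` values are illustrative placeholders for the real decoded inputs.

texts = ["What is in this image?", "Describe the scene.", "Count the objects."]
images = ["image_0"]  # only one real image was decoded

is_warmup = True
if is_warmup is True:
    # Repeat the first image so every text in the warmup batch has one.
    images += [images[0]] * (len(texts) - len(images))

assert len(images) == len(texts)

Keeping the image count aligned with the text count during warmup appears to be what prevents the processor-side mismatch the commit title describes as a crash.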
