
Commit 5f8d807

litianjian and DarkLight1337 authored
[Model][VLM] Add multi-video support for LLaVA-Onevision (#8905)
Co-authored-by: litianjian <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 8b0e4f2 commit 5f8d807
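
As a rough usage sketch (not part of this commit), a multi-video request to LLaVA-OneVision through vLLM's offline LLM API could look like the following once this change is in place. The random numpy clips, frame count, and exact prompt wording are illustrative assumptions; limit_mm_per_prompt and multi_modal_data are existing vLLM options.

import numpy as np

from vllm import LLM, SamplingParams

# Allow up to two videos in a single prompt via the limit_mm_per_prompt engine arg.
llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
          limit_mm_per_prompt={"video": 2})

# Placeholder clips: (num_frames, height, width, 3) uint8 arrays stand in for real videos.
video_a = np.random.randint(0, 255, (16, 480, 640, 3), dtype=np.uint8)
video_b = np.random.randint(0, 255, (16, 480, 640, 3), dtype=np.uint8)

prompt = ("<|im_start|>user <video><video>\nDescribe the two videos."
          "<|im_end|><|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        # One clip per <video> placeholder in the prompt.
        "multi_modal_data": {"video": [video_a, video_b]},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)

Each array in the "video" list pairs with one <video> placeholder in the prompt text, and limit_mm_per_prompt caps how many videos a single prompt may carry.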

File tree

5 files changed: +123 -162 lines changed

tests/models/decoder_only/vision_language/test_llava_onevision.py

Lines changed: 48 additions & 125 deletions
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type, overload
+from typing import List, Optional, Tuple, Type

 import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
@@ -9,9 +9,8 @@
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
-                          _VideoAssets)
-from ....utils import large_gpu_test
+from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput,
+                          PromptVideoInput, VllmRunner)
 from ...utils import check_logprobs_close

 # Video test
@@ -20,7 +19,7 @@
     "<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
 })

-models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
+models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]


 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -47,50 +46,16 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs


-@overload
-def run_video_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    video_assets: _VideoAssets,
-    model: str,
-    *,
-    size_factors: List[float],
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    num_frames: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    ...
-
-
-@overload
-def run_video_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    video_assets: _VideoAssets,
-    model: str,
-    *,
-    sizes: List[Tuple[int, int]],
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    num_frames: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    ...
+# Video test
+_LIMIT_VIDEO_PER_PROMPT = 4


 def run_video_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    video_assets: _VideoAssets,
+    inputs: List[Tuple[List[str], PromptVideoInput]],
     model: str,
     *,
-    size_factors: Optional[List[float]] = None,
-    sizes: Optional[List[Tuple[int, int]]] = None,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -99,38 +64,20 @@ def run_video_test(
     distributed_executor_backend: Optional[str] = None,
 ):
     torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
-
-    videos = [
-        sample_frames_from_video(asset.np_ndarrays, num_frames)
-        for asset in video_assets
-    ]
-
-    if size_factors is not None:
-        inputs_per_video = [(
-            [prompt for _ in size_factors],
-            [rescale_video_size(video, factor) for factor in size_factors],
-        ) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
-    elif sizes is not None:
-        inputs_per_video = [(
-            [prompt for _ in sizes],
-            [resize_video(video, size) for size in sizes],
-        ) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
-    else:
-        raise ValueError("You must provide either `size_factors` or `sizes`")
-
-    # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=16384,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
-        vllm_outputs_per_video = [
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"video": _LIMIT_VIDEO_PER_PROMPT
+                     }) as vllm_model:
+        vllm_outputs_per_input = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 videos=videos)
-            for prompts, videos in inputs_per_video
+            for prompts, videos in inputs
         ]

     def process(hf_inputs: BatchEncoding):
@@ -142,16 +89,16 @@ def process(hf_inputs: BatchEncoding):
                    dtype=dtype,
                    postprocess_inputs=process,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
-        hf_outputs_per_video = [
+        hf_outputs_per_input = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     videos=videos)
-            for prompts, videos in inputs_per_video
+            for prompts, videos in inputs
         ]

-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
-                                        vllm_outputs_per_video):
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_input,
+                                        vllm_outputs_per_input):
         # TODO: Check whether using original CLIPVisionModel can improve
         # consistency against HF
         check_logprobs_close(
@@ -165,74 +112,51 @@ def process(hf_inputs: BatchEncoding):
         )


-@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No video
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-        # Multi-scale
-        [0.25, 0.5, 1.0],
-    ],
-)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("num_frames", [16])
-def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
-                dtype, max_tokens, num_logprobs, num_frames) -> None:
-    """Inference result should be the same between hf and vllm.
-
-    All the image fixtures for the test is under tests/videos.
-    For huggingface runner, we provide the np.ndarray as input.
-    For vllm runner, we provide MultiModalDataDict objects
-    and corresponding MultiModalConfig as input.
-    Note, the text input is also adjusted to abide by vllm contract.
-    The text output is sanitized to be able to compare with hf.
-    """
+def test_models_multiple_video_inputs(hf_runner, vllm_runner, video_assets,
+                                      model, dtype, max_tokens, num_logprobs,
+                                      num_frames) -> None:
+    video = sample_frames_from_video(video_assets[0].np_ndarrays, num_frames)
+    inputs = [(
+        [
+            "<|im_start|>user <video><video>\nDescribe 2 videos. \
+                <|im_end|><|im_start|>assistant\n",
+            "<|im_start|>user <video><video>\nDescribe 2 videos. \
+                <|im_end|><|im_start|>assistant\n",
+            "<|im_start|>user <video><video><video><video>\nDescribe 4 videos. \
+                <|im_end|><|im_start|>assistant\n",
+            "<|im_start|>user <video>\nwhy is this video funny? \
+                <|im_end|><|im_start|>assistant\n",
+        ],
+        [
+            [video, video],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_video_size(video, 0.1),
+                video,
+            ],
+            [
+                video,
+                rescale_video_size(video, 0.25),
+                resize_video(video, (183, 488)),
+                resize_video(video, (488, 183))
+            ],
+            video,
+        ])]
     run_video_test(
         hf_runner,
         vllm_runner,
-        video_assets,
+        inputs,
         model,
-        size_factors=size_factors,
         dtype=dtype,
         max_tokens=max_tokens,
        num_logprobs=num_logprobs,
-        num_frames=num_frames,
         tensor_parallel_size=1,
-    )
-
-
-@large_gpu_test(min_gb=48)
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "sizes",
-    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
-)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("num_frames", [16])
-def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
-                            dtype, max_tokens, num_logprobs,
-                            num_frames) -> None:
-    run_video_test(
-        hf_runner,
-        vllm_runner,
-        video_assets,
-        model,
-        sizes=sizes,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
         num_frames=num_frames,
-        tensor_parallel_size=1,
     )


@@ -303,7 +227,6 @@ def process(hf_inputs: BatchEncoding):
         )


-@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])

vllm/model_executor/models/clip.py

Lines changed: 3 additions & 1 deletion
@@ -88,6 +88,7 @@ def dummy_image_for_clip(
 def dummy_video_for_clip(
     hf_config: CLIPVisionConfig,
     num_frames: int,
+    num_videos: int = 1,
     *,
     image_width_override: Optional[int] = None,
     image_height_override: Optional[int] = None,
@@ -99,7 +100,8 @@ def dummy_video_for_clip(
         image_height_override=image_height_override)
     np_frame = np.array(pil_frame["image"])
     mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
-    mm_data = {"video": mm_data_per_video}
+    video_data = [mm_data_per_video] * num_videos
+    mm_data = {"video": video_data}
     return mm_data

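
For reference, a small standalone sketch of the dummy-data shape this change produces; build_dummy_video_data and the 336x336 frame size are illustrative assumptions, not the actual vLLM helper:

import numpy as np

def build_dummy_video_data(np_frame: np.ndarray, num_frames: int, num_videos: int = 1):
    # One dummy video = the same frame repeated num_frames times: (num_frames, H, W, 3).
    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
    # Multi-video support: the "video" entry becomes a list with one array per video.
    video_data = [mm_data_per_video] * num_videos
    return {"video": video_data}

frame = np.zeros((336, 336, 3), dtype=np.uint8)  # placeholder frame
mm_data = build_dummy_video_data(frame, num_frames=16, num_videos=2)
print(len(mm_data["video"]), mm_data["video"][0].shape)  # 2 (16, 336, 336, 3)

The diff swaps the single array under "video" for a list containing one array per video, matching the multi-video input format used in the test above.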

0 commit comments
