@@ -74,11 +74,11 @@ def mm_model_cls():
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
-    "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}


-### Test for default processor logic & mm_processor_kwargs wrapping
+### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
    assert proc_inputs is proc_outputs


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_processor_default_kwargs(use_processor_mock, num_crops):
-    """Ensure input processors can use processor kwargs."""
-    dummy_registry = InputRegistry()
+def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
+    """Get the init / inference kwargs and expected num_crops for this test."""
    # If we have a value for num_crops, pass the override value and make
    # sure we get that value as a return-value from our mock processor,
    # otherwise fall back to the default value
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    init_kwargs = None if init_num_crops is None else {
+        "num_crops": init_num_crops
    }
-    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
+    inference_kwargs = None if inference_num_crops is None else {
+        "num_crops": inference_num_crops
+    }
+    if inference_num_crops is not None:
+        expected_seq_count = inference_num_crops
+    elif init_num_crops is not None:
+        expected_seq_count = init_num_crops
+    else:
+        expected_seq_count = DEFAULT_NUM_CROPS
+    return init_kwargs, inference_kwargs, expected_seq_count
+
+
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_input_processor_kwargs(use_processor_mock, init_num_crops,
+                                inference_num_crops):
+    """Ensure input processors can use processor kwargs."""
+    dummy_registry = InputRegistry()
+
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)

-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
-    assert num_crops_val == expected_num_crops
+    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=inference_kwargs))
+    assert num_crops_val == expected_seq_count


@pytest.mark.parametrize(
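Note on the hunk above: the new `_get_num_crops_info` helper encodes a precedence rule, an inference-time `mm_processor_kwargs` value wins over an init-time value, which in turn wins over `DEFAULT_NUM_CROPS`. Below is a minimal, self-contained sketch of that resolution order; the constant values and the `resolve_num_crops` name are illustrative assumptions, not part of vLLM.

# Illustrative only: mirrors the precedence _get_num_crops_info expects.
DEFAULT_NUM_CROPS = 4       # assumed value for this sketch
NUM_CROPS_OVERRIDE = 16     # assumed value for this sketch


def resolve_num_crops(init_kwargs=None, inference_kwargs=None):
    """Inference kwargs beat init kwargs, which beat the default."""
    for kwargs in (inference_kwargs, init_kwargs):
        if kwargs and "num_crops" in kwargs:
            return kwargs["num_crops"]
    return DEFAULT_NUM_CROPS


# The same three cases the test parametrizes over:
assert resolve_num_crops() == DEFAULT_NUM_CROPS
assert resolve_num_crops({"num_crops": NUM_CROPS_OVERRIDE}) == NUM_CROPS_OVERRIDE
assert resolve_num_crops({"num_crops": DEFAULT_NUM_CROPS},
                         {"num_crops": NUM_CROPS_OVERRIDE}) == NUM_CROPS_OVERRIDE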
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
+    # Should filter out the init time kwargs
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    # Should filter out the inference time kwargs
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=mm_processor_kwargs))
    assert num_crops_val == DEFAULT_NUM_CROPS

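The "sad kwarg" tests above assert that unsupported overrides are dropped rather than forwarded, at both init time and inference time. A minimal sketch of that kind of signature-based filtering is below; `filter_kwargs_for` and the default of 4 crops are assumptions for illustration, not the helper vLLM actually uses internally.

import inspect


def filter_kwargs_for(func, overrides):
    """Keep only overrides that name keyword parameters of func."""
    if not overrides:
        return {}
    params = inspect.signature(func).parameters
    return {
        name: value
        for name, value in overrides.items()
        if name in params and params[name].kind in (
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD)
    }


# With a get_num_crops-style callable: a known kwarg passes through,
# an unknown one is dropped so the default value is used instead.
get_num_crops = lambda ctx, *, num_crops=4: num_crops
assert filter_kwargs_for(get_num_crops, {"num_crops": 16}) == {"num_crops": 16}
assert filter_kwargs_for(get_num_crops, {"does_not_exist": 32}) == {}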
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
+                                       inference_num_crops):
    """Ensure custom mappers can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
-    }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
+
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
+                              mm_processor_kwargs=init_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
+                                          inference_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

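The hunk above (and the next one) replaces `patch.object` on `_default_input_mapper` with a real call to `register_input_mapper`, which returns a decorator keyed by model class, hence the `register_input_mapper(custom_mapper)(mm_model_cls())` call shape. A toy sketch of that registration pattern follows; `_ToyImagePlugin`, `_ToyModel`, and the default of 4 crops are made-up for illustration and do not reflect vLLM's actual `MultiModalPlugin` internals.

class _ToyImagePlugin:
    """Toy stand-in showing decorator-style mapper registration."""

    def __init__(self):
        self._input_mappers = {}

    def register_input_mapper(self, mapper=None):
        def wrapper(model_cls):
            # Record the override (or a pass-through default) per model class.
            self._input_mappers[model_cls] = mapper or (
                lambda ctx, data, **kwargs: data)
            return model_cls
        return wrapper

    def map_input(self, model_cls, data, mm_processor_kwargs=None):
        # Dispatch to the mapper registered for this model class and forward
        # any inference-time keyword overrides.
        mapper = self._input_mappers[model_cls]
        return mapper(None, data, **(mm_processor_kwargs or {}))


class _ToyModel:
    pass


plugin = _ToyImagePlugin()
# Same call shape as the test: register_input_mapper(mapper)(model_cls).
plugin.register_input_mapper(
    lambda ctx, data, *, num_crops=4: {"num_crops": num_crops})(_ToyModel)
assert plugin.map_input(_ToyModel, {}, {"num_crops": 16}) == {"num_crops": 16}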
@@ -316,24 +346,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
+    # Should filter out the init time kwargs
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    # Should filter out the inference time kwargs
+    mapped_inputs = mm_registry.map_input(
+        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1