@@ -758,7 +758,14 @@ def __init__(
         else:
             self.completion_only_loss = args.completion_only_loss

-        if data_collator is None and not self._is_vlm:
+        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+        if self._is_vision_dataset and not self._is_vlm:
+            raise ValueError(
+                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                "model does not seem to be a vision-language model. Please check your model and dataset."
+            )
+
+        if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
             pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
@@ -777,7 +784,7 @@ def __init__(
                 return_position_ids=use_flash_attention,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
-        elif data_collator is None and self._is_vlm:
+        elif data_collator is None and self._is_vision_dataset:
             data_collator = DataCollatorForVisionLanguageModeling(
                 processor=processing_class,
                 max_length=args.max_length,
@@ -805,7 +812,9 @@ def __init__(
         # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
         # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
         skip_prepare_dataset = (
-            args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
+            args.dataset_kwargs is not None
+            and args.dataset_kwargs.get("skip_prepare_dataset", False)
+            or self._is_vision_dataset
         )
         if not skip_prepare_dataset:
             if self.completion_only_loss and formatting_func:
@@ -959,22 +968,36 @@ def add_eos(example, eos_token):
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

-            def tokenize(example, processing_class, dataset_text_field, assistant_only_loss):
+            def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
                 if "prompt" in example:  # prompt-completion case
                     output = {}
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["prompt"], num_images=0)
+                            prepare_multimodal_messages(example["completion"], num_images=0)
                         prompt_ids = processing_class.apply_chat_template(
                             example["prompt"],
+                            tokenize=True,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                         prompt_completion_processed = processing_class.apply_chat_template(
                             example["prompt"] + example["completion"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_completion_processed = {
+                            k: v[0] if isinstance(v[0], list) else v
+                            for k, v in prompt_completion_processed.items()
+                        }
                         prompt_completion_ids = prompt_completion_processed["input_ids"]
                         if "assistant_masks" in prompt_completion_processed:
                             output["assistant_masks"] = prompt_completion_processed["assistant_masks"]
@@ -999,13 +1022,19 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)

                 else:  # language modeling case
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["messages"], num_images=0)
                         processed = processing_class.apply_chat_template(
                             example["messages"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
                         if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
                             raise RuntimeError(
                                 "You're using `assistant_only_loss=True`, but at least one example has no "
@@ -1020,7 +1049,7 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
                 return output

             dataset = dataset.map(
-                tokenize,
+                tokenize_fn,
                 fn_kwargs={
                     "processing_class": processing_class,
                     "dataset_text_field": args.dataset_text_field,
@@ -1064,7 +1093,7 @@ def _set_signature_columns_if_needed(self):
         # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
         # dataset. So we need to override the default signature columns to include "completion_mask" as well.
         if self._signature_columns is None:
-            if self._is_vlm:
+            if self._is_vision_dataset:
                 self._signature_columns = ["messages", "prompt", "completion", "images"]
             else:
                 self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
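Illustration of the routing change above: the trainer now decides between the text and vision collator paths from the dataset itself (presence of an "image" or "images" key in a sample) rather than from the model type, and fails fast in `__init__` when a vision dataset is paired with a non-VLM model. A minimal sketch; the standalone `is_vision_dataset` helper is hypothetical, as the trainer computes this inline and stores it as `self._is_vision_dataset`:

# Sketch: dataset-driven vision detection (hypothetical helper; inlined in the diff).
def is_vision_dataset(dataset_sample: dict) -> bool:
    return "image" in dataset_sample or "images" in dataset_sample

text_sample = {"messages": [{"role": "user", "content": "Hello"}]}
vision_sample = {"images": ["<PIL.Image.Image>"], "messages": []}

assert not is_vision_dataset(text_sample)  # -> DataCollatorForLanguageModeling path
assert is_vision_dataset(vision_sample)    # -> DataCollatorForVisionLanguageModeling path

Raising at construction time surfaces the model/dataset mismatch immediately, instead of letting it show up later as an obscure collation or forward-pass error.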
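The repeated "Fix transformers inconsistency" blocks all apply the same shape normalization. A minimal sketch of that pattern, assuming a processing class whose `apply_chat_template(..., tokenize=True)` returns a flat list of token ids for plain tokenizers but a batch of size 1 (a list of lists) for VLM processors; `unbatch` is a hypothetical name, since the trainer inlines the expression instead of using a helper:

# Sketch of the normalization pattern (hypothetical helper; inlined in the diff).
def unbatch(value):
    # VLM processors return batch-shaped outputs even for a single example;
    # LLM tokenizers return a flat list of ints. Keep row 0 when batched.
    return value[0] if isinstance(value[0], list) else value

assert unbatch([101, 2023, 102]) == [101, 2023, 102]    # LLM: already flat
assert unbatch([[101, 2023, 102]]) == [101, 2023, 102]  # VLM: take the only row

# With return_dict=True, the same rule is applied per key, as in the diff:
processed = {"input_ids": [[1, 2]], "attention_mask": [[1, 1]]}
processed = {k: unbatch(v) for k, v in processed.items()}
assert processed == {"input_ids": [1, 2], "attention_mask": [1, 1]}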