@@ -929,7 +929,7 @@ def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image':
929
929
return image
930
930
931
931
932
- def _load_video (video_path : str ) -> np .ndarray :
932
+ def _load_video_llava (video_path : str ) -> np .ndarray :
933
933
import av
934
934
container = av .open (video_path )
935
935
total_frames = container .streams .video [0 ].frames
@@ -1036,7 +1036,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1036
1036
if idx_list :
1037
1037
idx = idx_list [0 ]
1038
1038
images_path = example .get ('images' ) or []
1039
- image = _load_image (images_path [0 ])
1039
+ image = _read_batch (images_path ) [0 ]
1040
1040
placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
1041
1041
placeholder_id = self .tokenizer .encode (placeholder , add_special_tokens = False )
1042
1042
input_ids = (input_ids [:idx ] + placeholder_id + input_ids [idx + 1 :])
@@ -1627,7 +1627,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1627
1627
else :
1628
1628
videos_path .append (media_file )
1629
1629
if len (videos_path ) > 0 :
1630
- videos = _read_batch (videos_path , _load_video )
1630
+ videos = _read_batch (videos_path , _load_video_llava )
1631
1631
video_processor = self .tokenizer .processor .video_processor
1632
1632
video_inputs = video_processor (videos , return_tensors = 'pt' ).to (self .model .dtype )
1633
1633
inputs ['pixel_values_videos' ] = video_inputs ['pixel_values_videos' ]
@@ -1766,9 +1766,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1766
1766
if len (inputs ) == 0 :
1767
1767
return inputs , {}
1768
1768
image_path = example .get ('images' ) or []
1769
- if image_path :
1770
- raw_image = _load_image ( image_path [ 0 ])
1771
- pixel_values = self .tokenizer .processor .image_processor (raw_image , return_tensors = 'pt' )['pixel_values' ]
1769
+ raw_image = _read_batch ( image_path )
1770
+ if raw_image :
1771
+ pixel_values = self .tokenizer .processor .image_processor (raw_image [ 0 ] , return_tensors = 'pt' )['pixel_values' ]
1772
1772
inputs ['pixel_values' ] = pixel_values .to (self .model .dtype )
1773
1773
return inputs , {}
1774
1774
@@ -1806,9 +1806,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1806
1806
inputs ['token_type_ids' ] = [0 ] * n + [1 ] * n2
1807
1807
else :
1808
1808
inputs ['token_type_ids' ] = [0 ] * len (inputs ['input_ids' ])
1809
- if image_path :
1810
- raw_image = _load_image ( image_path [ 0 ])
1811
- model_inputs = processor (text = example ['query' ], images = raw_image , return_tensors = 'pt' )
1809
+ raw_image = _read_batch ( image_path )
1810
+ if raw_image :
1811
+ model_inputs = processor (text = example ['query' ], images = raw_image [ 0 ] , return_tensors = 'pt' )
1812
1812
inputs ['pixel_values' ] = model_inputs ['pixel_values' ]
1813
1813
return inputs , {}
1814
1814
@@ -1949,9 +1949,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1949
1949
example ['query' ], example ['history' ], images_path = replace_img_tag (example ['query' ], history ,
1950
1950
'<image_placeholder>' )
1951
1951
inputs , _ = super ().encode (example )
1952
- images_path .extend (example .get ('images' ) or [])
1953
1952
if len (inputs ) == 0 :
1954
1953
return inputs , {}
1954
+ images_path .extend (example .get ('images' ) or [])
1955
1955
images = _read_batch (images_path )
1956
1956
processor = self .tokenizer .processor
1957
1957
input_ids , labels = inputs ['input_ids' ], inputs ['labels' ]
@@ -2024,34 +2024,35 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa
2024
2024
2025
2025
def encode (self , example : Dict [str , Any ]) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
2026
2026
inputs , _ = super ().encode (example )
2027
- images_path = example .get ('images' ) or []
2028
- image = _load_image (images_path [0 ]) if len (images_path ) >= 1 else []
2029
2027
if len (inputs ) == 0 :
2030
2028
return inputs , {}
2029
+ images_path = example .get ('images' ) or []
2030
+ image = _read_batch (images_path )
2031
2031
inputs .pop ('loss_scale' , None )
2032
2032
model = self .model
2033
2033
inputs2 = model .build_conversation_input_ids (
2034
- self .tokenizer , query = example ['query' ], history = example .get ('history' ), images = [ image ] )
2035
- image_token_len = inputs2 ['token_type_ids' ].sum ()
2034
+ self .tokenizer , query = example ['query' ], history = example .get ('history' ), images = image )
2035
+ image_token_len = inputs2 ['token_type_ids' ].sum (). item ()
2036
2036
input_ids = inputs ['input_ids' ]
2037
2037
labels = inputs ['labels' ]
2038
2038
inputs ['token_type_ids' ] = [0 ] + [1 ] * image_token_len + [0 ] * len (input_ids [1 :])
2039
2039
inputs ['input_ids' ] = input_ids [:1 ] + [self .tokenizer .pad_token_id ] * image_token_len + input_ids [1 :]
2040
2040
if labels is not None :
2041
2041
inputs ['labels' ] = labels [:1 ] + [- 100 ] * image_token_len + labels [1 :]
2042
- dtype = model .dtype
2043
- inputs ['images' ] = [[img .to (dtype = dtype )] for img in inputs2 ['images' ]]
2044
- if 'cross_images' in inputs2 :
2045
- # is cogagent
2046
- inputs ['cross_images' ] = [[cross_img .to (dtype = dtype )] for cross_img in inputs2 ['cross_images' ]]
2042
+ if len (image ) > 0 :
2043
+ dtype = model .dtype
2044
+ inputs ['images' ] = [[img .to (dtype = dtype )] for img in inputs2 ['images' ]]
2045
+ if 'cross_images' in inputs2 :
2046
+ # is cogagent
2047
+ inputs ['cross_images' ] = [[cross_img .to (dtype = dtype )] for cross_img in inputs2 ['cross_images' ]]
2047
2048
return inputs , {}
2048
2049
2049
2050
def data_collator (self , batch : List [Dict [str , Any ]], padding_to : Optional [int ] = None ) -> Dict [str , Any ]:
2050
2051
res = super ().data_collator (batch , padding_to )
2051
- is_cogagent = 'cross_images' in batch [0 ]
2052
- keys = ['images' , 'cross_images' ] if is_cogagent else ['images' ]
2052
+ keys = ['images' , 'cross_images' ]
2053
2053
for key in keys :
2054
- res [key ] = [b [key ][0 ] for b in batch ]
2054
+ if key in batch [0 ]:
2055
+ res [key ] = [b [key ][0 ] for b in batch ]
2055
2056
token_type_ids = [torch .tensor (b ['token_type_ids' ]) for b in batch ]
2056
2057
token_type_ids = pad_sequence (token_type_ids , batch_first = True , padding_value = 0 )
2057
2058
res ['token_type_ids' ] = token_type_ids
@@ -2107,10 +2108,10 @@ def check_example(self, example):
2107
2108
2108
2109
def encode (self , example : Dict [str , Any ]) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
2109
2110
inputs , _ = super ().encode (example )
2110
- images_path = example ['images' ]
2111
- image = _load_image (images_path [0 ])
2112
2111
if len (inputs ) == 0 :
2113
2112
return inputs , {}
2113
+ images_path = example ['images' ]
2114
+ image = _load_image (images_path [0 ])
2114
2115
input_ids = inputs ['input_ids' ]
2115
2116
labels = inputs ['labels' ]
2116
2117
idx_list = _findall (input_ids , - 1 )
0 commit comments