@@ -1081,6 +1081,8 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1081
1081
image = _read_from_path (image_path )
1082
1082
images .append (image )
1083
1083
inputs , _ = super ().encode (example )
1084
+ if len (inputs ) == 0 :
1085
+ return inputs , {}
1084
1086
input_ids = inputs ['input_ids' ]
1085
1087
labels = inputs ['labels' ]
1086
1088
idx_list = _findall (input_ids , 1 )[1 :] # 1: <s>
@@ -1330,7 +1332,16 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
1330
1332
register_template (TemplateType .minicpm , Template (['<s>{{SYSTEM}}' ], ['<用户>{{QUERY}}<AI>' ], [], ['</s>' ]))
1331
1333
1332
1334
1333
- class MiniCPMVTemlate (Template ):
1335
+ def _remove_idx (arr : List [int ], idx_list : List [int ]) -> List [int ]:
1336
+ res = []
1337
+ idx_set = set (idx_list )
1338
+ for i , x in enumerate (arr ):
1339
+ if i not in idx_set :
1340
+ res .append (x )
1341
+ return res
1342
+
1343
+
1344
+ class MiniCPMVTemplate (Template ):
1334
1345
1335
1346
def __init__ (self , * args , ** kwargs ):
1336
1347
self .is_v2_5 = kwargs .pop ('is_v2_5' , False )
@@ -1345,32 +1356,22 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1345
1356
return inputs , {}
1346
1357
input_ids = inputs ['input_ids' ]
1347
1358
labels = inputs ['labels' ]
1348
-
1349
- img_start_idxs = np .where (np .array (input_ids ) == self .tokenizer .im_start_id )[0 ]
1350
- if len (img_start_idxs ) > 1 : # if mutli-round, input_ids have mutli <image><unk></image>\n
1351
- start = 0
1352
- new_input_ids = []
1353
- new_labels = []
1354
- for idx in img_start_idxs [1 :]:
1355
- new_input_ids = new_input_ids + input_ids [start :idx ]
1356
- if labels is not None :
1357
- new_labels = new_labels + labels [start :idx ]
1358
- start = idx + 4 # skip <image><unk></image>\n
1359
- new_input_ids = new_input_ids + input_ids [start :]
1360
- input_ids = new_input_ids
1359
+ idx_list = _findall (input_ids , - 1 )
1360
+ if len (idx_list ) >= 2 :
1361
+ input_ids = _remove_idx (input_ids , idx_list [1 :])
1361
1362
if labels is not None :
1362
- new_labels = new_labels + labels [start :]
1363
- labels = new_labels
1364
-
1365
- idx = img_start_idxs [0 ] + 1 # first <unk>
1363
+ labels = _remove_idx (labels , idx_list [1 :])
1364
+ idx = idx_list [0 ]
1366
1365
config = self .model .config
1367
1366
tgt_sizes = None
1368
- if config .slice_mode :
1367
+ slice_mode = getattr (config , 'slice_mode' , False )
1368
+ if slice_mode :
1369
1369
images , placeholder = self .model .get_slice_image_placeholder (image , self .tokenizer )
1370
+ placeholder += '\n '
1370
1371
placeholder_id = self .tokenizer .encode (placeholder , add_special_tokens = False )
1371
- input_ids = (input_ids [:idx - 1 ] + placeholder_id + input_ids [idx + 2 :])
1372
+ input_ids = (input_ids [:idx ] + placeholder_id + input_ids [idx + 1 :])
1372
1373
if labels is not None :
1373
- labels = (labels [:idx - 1 ] + [- 100 ] * len (placeholder_id ) + labels [idx + 2 :])
1374
+ labels = (labels [:idx ] + [- 100 ] * len (placeholder_id ) + labels [idx + 1 :])
1374
1375
input_tensor_ids = torch .tensor (input_ids )
1375
1376
image_start_idx = torch .where (input_tensor_ids == self .tokenizer .im_start_id )[0 ]
1376
1377
image_start_idx += 1
@@ -1393,9 +1394,11 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
1393
1394
else :
1394
1395
pixel_values = [self .model .transform (img ).to (device = self .model .device ) for img in images ]
1395
1396
else :
1396
- input_ids = (input_ids [:idx ] + [self .tokenizer .unk_token_id ] * config .query_num + input_ids [idx + 1 :])
1397
+ placeholder = '<image>' + '<unk>' * config .query_num + '</image>\n '
1398
+ placeholder_id = self .tokenizer .encode (placeholder , add_special_tokens = False )
1399
+ input_ids = (input_ids [:idx ] + placeholder_id + input_ids [idx + 1 :])
1397
1400
if labels is not None :
1398
- labels = (labels [:idx ] + [- 100 ] * config . query_num + labels [idx + 1 :])
1401
+ labels = (labels [:idx ] + [- 100 ] * len ( placeholder_id ) + labels [idx + 1 :])
1399
1402
image_bound = [torch .tensor ([[idx , idx + config .query_num ]])]
1400
1403
pixel_values = [self .model .transform (image ).to (device = self .model .device )]
1401
1404
data = {
@@ -1418,7 +1421,7 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
1418
1421
1419
1422
register_template (
1420
1423
TemplateType .minicpm_v ,
1421
- MiniCPMVTemlate (['<s>{{SYSTEM}}' ], ['<用户><image><unk></image> \n {{QUERY}}<AI>' ], [], ['</s>' ]),
1424
+ MiniCPMVTemplate (['<s>{{SYSTEM}}' ], ['<用户>' , [ - 1 ], ' {{QUERY}}<AI>' ], [], ['</s>' ]),
1422
1425
use_model = True ,
1423
1426
lazy_tokenize = True ,
1424
1427
infer_media_type = 'dialogue' ,
@@ -1427,11 +1430,11 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
1427
1430
1428
1431
register_template (
1429
1432
TemplateType .minicpm_v_v2_5 ,
1430
- MiniCPMVTemlate (['<|begin_of_text|>{{SYSTEM}}' ], [
1431
- '<|start_header_id|>user<|end_header_id|>\n \n <image><unk></image> \n {{QUERY}}<|eot_id|>'
1433
+ MiniCPMVTemplate (['<|begin_of_text|>{{SYSTEM}}' ], [
1434
+ '<|start_header_id|>user<|end_header_id|>\n \n ' , [ - 1 ], ' {{QUERY}}<|eot_id|>'
1432
1435
'<|start_header_id|>assistant<|end_header_id|>\n \n '
1433
1436
], ['<|eot_id|>' ], ['<|eot_id|>' ],
1434
- is_v2_5 = True ),
1437
+ is_v2_5 = True ),
1435
1438
use_model = True ,
1436
1439
lazy_tokenize = True ,
1437
1440
infer_media_type = 'dialogue' ,
0 commit comments