@@ -8,8 +8,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.utils.rnn import pad_sequence
-from transformers import (DataCollatorForSeq2Seq, PreTrainedTokenizerBase,
-                          StoppingCriteria)
+from transformers import PreTrainedTokenizerBase, StoppingCriteria
 
 from swift.llm.agent.utils import calculate_loss_scale
 
@@ -187,10 +186,6 @@ def _init_template(self,
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
         self.use_loss_scale = kwargs.get('use_loss_scale', False)
-        self._data_collator = DataCollatorForSeq2Seq(
-            tokenizer=self.tokenizer,
-            label_pad_token_id=self.tokenizer.pad_token_id,
-        )
         for key in [
                 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
         ]:
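Note on the two hunks above: the template no longer builds a `DataCollatorForSeq2Seq`; batching is now done by hand in `data_collator` below. One behavioural difference visible in the diff is label padding: the removed collator was configured with `label_pad_token_id=self.tokenizer.pad_token_id`, while the new code pads `labels` with -100, the value `CrossEntropyLoss` ignores by default. A minimal sketch of the removed behaviour, for comparison (the model id is only illustrative):

```python
# What the removed path did, roughly: HF's DataCollatorForSeq2Seq pads
# 'input_ids' with the tokenizer's pad token and 'labels' with
# label_pad_token_id (here set to the pad token id, as in the removed code).
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative model id
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    label_pad_token_id=tokenizer.pad_token_id,  # not -100, unlike the new code
)
features = [
    {'input_ids': [1, 2, 3], 'labels': [1, 2, 3]},
    {'input_ids': [4, 5], 'labels': [4, 5]},
]
batch = collator(features, return_tensors='pt')
print(batch['input_ids'].shape, batch['labels'])
```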
@@ -391,28 +386,55 @@ def concat_tokenizer_kwargs(
         assert len(old_tokenizer_kwargs) == 0
         return curr_tokenizer_kwargs
 
-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
         """
         Args:
             batch(`List[Dict[str, Any]]`): The input data in batch
-            pad_to_multiple_of(`int`, optional): Whether padding to the multiple of an integer value.
+            padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
+                will be padded to the `longest`
         """
-        self._data_collator.pad_to_multiple_of = pad_to_multiple_of
-        if pad_to_multiple_of:
-            self.tokenizer.padding_side = 'right'
-        loss_scale = [torch.tensor(b.pop('loss_scale'))
+        tokenizer = self.tokenizer
+        assert tokenizer.pad_token_id is not None
+        input_ids = [torch.tensor(b['input_ids']) for b in batch]
+        labels = [torch.tensor(b['labels']) for b in batch]
+        loss_scale = [torch.tensor(b['loss_scale'])
                       for b in batch] if 'loss_scale' in batch[0] else None
-        res = self._data_collator(batch, return_tensors='pt')
-        padding_to = res['input_ids'].shape[1]
+        attention_mask = [
+            torch.ones(len(input_ids[i]), dtype=torch.int64)
+            for i in range(len(input_ids))
+        ]
+
+        if padding_to is not None:
+            padding_len = padding_to - input_ids[0].shape[-1]
+            if padding_len > 0:
+                input_ids[0] = F.pad(input_ids[0], (0, padding_len),
+                                     'constant', tokenizer.pad_token_id)
+                attention_mask[0] = F.pad(attention_mask[0], (0, padding_len),
+                                          'constant', 0)
+                labels[0] = F.pad(labels[0], (0, padding_len), 'constant',
+                                  -100)
         if loss_scale:
             loss_scale[0] = F.pad(loss_scale[0],
-                                  (0, padding_to - loss_scale[0].shape[-1]),
+                                  (0, padding_to - labels[0].shape[-1]),
                                   'constant', 0.)
+
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        attention_mask = pad_sequence(
+            attention_mask, batch_first=True, padding_value=0)
+        if loss_scale:
             loss_scale = pad_sequence(
                 loss_scale, batch_first=True, padding_value=0.)
+        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
+
+        res = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'labels': labels,
+        }
+        if loss_scale is not None:
             res['loss_scale'] = loss_scale
         return res
 
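Review note on the new base implementation: padding to the longest sequence in the batch is delegated to `pad_sequence`, and a fixed batch length (`padding_to`) is obtained by first right-padding the first sample to that length with `F.pad`, after which `pad_sequence` brings every other sample up to it. A self-contained sketch of that trick (token ids are made up):

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
# Toy batch: three sequences of different lengths.
input_ids = [torch.tensor(x) for x in ([11, 12, 13, 14], [21, 22], [31, 32, 33])]

padding_to = 8  # desired fixed batch length
padding_len = padding_to - input_ids[0].shape[-1]
if padding_len > 0:
    # Right-pad only the first sample; pad_sequence then matches the rest to it.
    input_ids[0] = F.pad(input_ids[0], (0, padding_len), 'constant', pad_token_id)

batch = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
print(batch.shape)  # torch.Size([3, 8])
```

Note this only lengthens the batch: if `padding_to` is smaller than the longest sample, nothing is truncated and the batch still comes out at the longest length.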
@@ -579,11 +601,10 @@ def encode(
         inputs['images'] = image_tensor.to(model.dtype)
         return inputs, {}
 
-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         res['images'] = torch.concat([b['images'] for b in batch])
         return res
 
@@ -887,11 +908,10 @@ def encode(
         inputs['image_sizes'] = image_sizes
         return inputs, {}
 
-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         res['images'] = torch.concat([b['images'] for b in batch])
         res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
         return res
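Aside on the `image_sizes` line above: `sum(list_of_lists, start=[])` is used as a list flattener, turning per-sample lists of sizes into one flat list for the whole batch (the `start` keyword needs Python 3.8+). For example:

```python
# Sizes below are made up for illustration.
image_sizes_per_sample = [[(336, 336)], [(672, 336), (336, 336)]]
flat = sum(image_sizes_per_sample, start=[])
print(flat)  # [(336, 336), (672, 336), (336, 336)]
```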
@@ -1073,11 +1093,10 @@ def encode(
             len(inputs['input_ids']) - len(token_type_ids))
         return inputs, {}
 
-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         is_cogagent = 'cross_images' in batch[0]
         keys = ['images', 'cross_images'] if is_cogagent else ['images']
         for key in keys:
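Finally, a usage sketch of the renamed parameter: `padding_to` is an absolute target length rather than a pad-to-multiple value. The `template` and `encoded_dataset` names below are placeholders, not part of this change:

```python
from functools import partial
from torch.utils.data import DataLoader

# padding_to=None pads each batch to its longest sample; an integer pads
# every batch to that fixed length instead (e.g. for static-shape setups).
collate_fn = partial(template.data_collator, padding_to=512)
loader = DataLoader(encoded_dataset, batch_size=8, collate_fn=collate_fn)
```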