Commit 985eea3

Fix data_collator (#674)
1 parent 257632e commit 985eea3

2 files changed: +54 -36 lines changed


swift/llm/sft.py

Lines changed: 2 additions & 3 deletions
@@ -182,9 +182,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     if val_dataset is not None:
         val_dataset = LazyLLMDataset(val_dataset, template)

-    pad_to_multiple_of = 8 if args.sft_type == 'longlora' else None
-    data_collator = partial(
-        template.data_collator, pad_to_multiple_of=pad_to_multiple_of)
+    padding_to = args.max_length if args.sft_type == 'longlora' else None
+    data_collator = partial(template.data_collator, padding_to=padding_to)

     # Trainer
     logger.info(f'training_args: {training_args}')
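For context, a small sketch of how the re-bound collator behaves (the `data_collator` stub and the value 4096 are illustrative, not part of this commit): `functools.partial` freezes `padding_to`, so the Trainer keeps calling the collator with only the batch, and for `sft_type == 'longlora'` every batch comes out padded to the fixed `args.max_length` instead of to a multiple of 8.

from functools import partial

# Illustrative stand-in for Template.data_collator: pad each sample to
# `padding_to` when given, otherwise to the longest sample in the batch.
def data_collator(batch, padding_to=None):
    target = padding_to if padding_to is not None else max(
        len(b['input_ids']) for b in batch)
    return [b['input_ids'] + [0] * (target - len(b['input_ids'])) for b in batch]

# longlora: padding_to frozen to args.max_length (4096 is just an example);
# other sft_type values keep padding_to=None, i.e. longest-in-batch padding.
collator = partial(data_collator, padding_to=4096)
print(len(collator([{'input_ids': [1, 2, 3]}])[0]))  # -> 4096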

swift/llm/utils/template.py

Lines changed: 52 additions & 33 deletions
@@ -8,8 +8,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.utils.rnn import pad_sequence
-from transformers import (DataCollatorForSeq2Seq, PreTrainedTokenizerBase,
-                          StoppingCriteria)
+from transformers import PreTrainedTokenizerBase, StoppingCriteria

 from swift.llm.agent.utils import calculate_loss_scale

@@ -187,10 +186,6 @@ def _init_template(self,
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
         self.use_loss_scale = kwargs.get('use_loss_scale', False)
-        self._data_collator = DataCollatorForSeq2Seq(
-            tokenizer=self.tokenizer,
-            label_pad_token_id=self.tokenizer.pad_token_id,
-        )
         for key in [
                 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
         ]:
@@ -391,28 +386,55 @@ def concat_tokenizer_kwargs(
         assert len(old_tokenizer_kwargs) == 0
         return curr_tokenizer_kwargs

-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
         """
         Args:
             batch(`List[Dict[str, Any]]`): The input data in batch
-            pad_to_multiple_of(`int`, optional): Whether padding to the multiple of an integer value.
+            padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
+                will be padded to the `longest`
         """
-        self._data_collator.pad_to_multiple_of = pad_to_multiple_of
-        if pad_to_multiple_of:
-            self.tokenizer.padding_side = 'right'
-        loss_scale = [torch.tensor(b.pop('loss_scale'))
+        tokenizer = self.tokenizer
+        assert tokenizer.pad_token_id is not None
+        input_ids = [torch.tensor(b['input_ids']) for b in batch]
+        labels = [torch.tensor(b['labels']) for b in batch]
+        loss_scale = [torch.tensor(b['loss_scale'])
                       for b in batch] if 'loss_scale' in batch[0] else None
-        res = self._data_collator(batch, return_tensors='pt')
-        padding_to = res['input_ids'].shape[1]
+        attention_mask = [
+            torch.ones(len(input_ids[i]), dtype=torch.int64)
+            for i in range(len(input_ids))
+        ]
+
+        if padding_to is not None:
+            padding_len = padding_to - input_ids[0].shape[-1]
+            if padding_len > 0:
+                input_ids[0] = F.pad(input_ids[0], (0, padding_len),
+                                     'constant', tokenizer.pad_token_id)
+                attention_mask[0] = F.pad(attention_mask[0], (0, padding_len),
+                                          'constant', 0)
+                labels[0] = F.pad(labels[0], (0, padding_len), 'constant',
+                                  -100)
         if loss_scale:
             loss_scale[0] = F.pad(loss_scale[0],
-                                  (0, padding_to - loss_scale[0].shape[-1]),
+                                  (0, padding_to - labels[0].shape[-1]),
                                   'constant', 0.)
+
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        attention_mask = pad_sequence(
+            attention_mask, batch_first=True, padding_value=0)
+        if loss_scale:
             loss_scale = pad_sequence(
                 loss_scale, batch_first=True, padding_value=0.)
+        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
+
+        res = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'labels': labels,
+        }
+        if loss_scale is not None:
+            res['loss_scale'] = loss_scale
         return res

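For reference, a minimal standalone sketch of the collation logic this hunk introduces (the `collate` helper, `pad_token_id=0`, and the toy batch are made up for illustration; the real method lives on `Template` and also handles `loss_scale`). The trick behind `padding_to` is that only the first sample is right-padded to the fixed length, after which `pad_sequence` brings every other sample up to that same width:

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# Standalone sketch of the collation logic above (illustrative only).
def collate(batch, padding_to=None, pad_token_id=0):
    input_ids = [torch.tensor(b['input_ids']) for b in batch]
    labels = [torch.tensor(b['labels']) for b in batch]
    attention_mask = [torch.ones(len(ids), dtype=torch.int64) for ids in input_ids]
    if padding_to is not None:
        # Right-pad only the first sample to the fixed length; pad_sequence
        # then pads every other sample up to that same length.
        padding_len = padding_to - input_ids[0].shape[-1]
        if padding_len > 0:
            input_ids[0] = F.pad(input_ids[0], (0, padding_len), 'constant', pad_token_id)
            attention_mask[0] = F.pad(attention_mask[0], (0, padding_len), 'constant', 0)
            labels[0] = F.pad(labels[0], (0, padding_len), 'constant', -100)
    return {
        'input_ids': pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id),
        'attention_mask': pad_sequence(attention_mask, batch_first=True, padding_value=0),
        'labels': pad_sequence(labels, batch_first=True, padding_value=-100),
    }

batch = [{'input_ids': [1, 2, 3], 'labels': [-100, 2, 3]},
         {'input_ids': [4, 5], 'labels': [4, 5]}]
print(collate(batch)['input_ids'].shape)                # torch.Size([2, 3]): longest sample
print(collate(batch, padding_to=8)['input_ids'].shape)  # torch.Size([2, 8]): fixed length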

@@ -579,11 +601,10 @@ def encode(
         inputs['images'] = image_tensor.to(model.dtype)
         return inputs, {}

-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         res['images'] = torch.concat([b['images'] for b in batch])
         return res


@@ -887,11 +908,10 @@ def encode(
         inputs['image_sizes'] = image_sizes
         return inputs, {}

-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         res['images'] = torch.concat([b['images'] for b in batch])
         res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
         return res
@@ -1073,11 +1093,10 @@ def encode(
                 len(inputs['input_ids']) - len(token_type_ids))
         return inputs, {}

-    def data_collator(
-            self,
-            batch: List[Dict[str, Any]],
-            pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, pad_to_multiple_of)
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
         is_cogagent = 'cross_images' in batch[0]
         keys = ['images', 'cross_images'] if is_cogagent else ['images']
         for key in keys:
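The three modified subclasses above all follow the same override pattern: text fields are delegated to the base `Template.data_collator` (now taking `padding_to`), and the per-sample vision features are merged afterwards. A hedged sketch of that pattern (the subclass name and the exact keys are illustrative; cogagent-style templates additionally loop over `cross_images`):

from typing import Any, Dict, List, Optional

import torch

from swift.llm.utils.template import Template


class MyVisionTemplate(Template):  # hypothetical subclass for illustration
    def data_collator(self,
                      batch: List[Dict[str, Any]],
                      padding_to: Optional[int] = None) -> Dict[str, Any]:
        # Text fields (input_ids / attention_mask / labels / loss_scale)
        # are padded by the base collator; vision features are merged here.
        res = super().data_collator(batch, padding_to)
        # Stack per-sample image tensors along the batch dimension.
        res['images'] = torch.concat([b['images'] for b in batch])
        # sum(..., start=[]) flattens the per-sample lists into one list.
        res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
        return res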
