Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def split_multi_medias(_inputs):
positive_encoded = self._encode_truncated(positive)
for key in positive_encoded:
_encoded[f'positive_{key}'] = positive_encoded[key]
_encoded[f'negative_{key}'] = []
labels.append(float(inputs.label) if inputs.label is not None else 1.0)

rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
Expand All @@ -381,7 +382,7 @@ def split_multi_medias(_inputs):
split_multi_medias(negative)
negative_encoded = self._encode_truncated(negative)
for key in negative_encoded:
_encoded[f'negative{i}_{key}'] = negative_encoded[key]
_encoded[f'negative_{key}'].append(negative_encoded[key])
labels.append(0.0)

_encoded['labels'] = labels
Expand Down Expand Up @@ -1314,10 +1315,18 @@ def _embedding_data_collator(self,
new_batch = []
for b in batch:
keys = [key for key in b.keys() if 'negative' in key]
max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
max_neg = None
for key in keys:
value_list = b[key]
suffix = key[len('negative_'):]
max_neg = len(value_list)
for i, value in enumerate(value_list):
b[f'negative{i}_{suffix}'] = value
b.pop(key)

indexes = ['anchor_', 'positive_']
if max_neg is not None:
for i in range(0, max_neg + 1):
for i in range(0, max_neg):
indexes.append(f'negative{i}_')
for prefix in indexes:
new_batch += self._fetch_inputs_startswith([b], prefix)
Expand Down
21 changes: 21 additions & 0 deletions swift/trainers/sequence_parallel/ulysses.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,27 @@ def prepare_trainer(self, trainer):
trainer._get_per_token_logps = MethodType(_get_per_token_logps, trainer)
trainer.split_by_mini_batches = MethodType(split_by_mini_batches, trainer)

class DataloaderWrap:
    """Thin proxy around a train DataLoader that rescales its reported length.

    Defined inside ``prepare_trainer``, so the enclosing method's ``self``
    (the sequence-parallel helper owning ``sp_world_size``) is captured by
    closure — see the note in ``__len__``.
    """

    def __init__(self, dataloader):
        # The wrapped dataloader (any iterable with __len__ works).
        self.dataloader = dataloader

    def __getattr__(self, item):
        # Fall through to the wrapped dataloader for any attribute not
        # defined on the wrapper itself (e.g. dataset, batch_size, sampler).
        # Only called when normal attribute lookup fails.
        return getattr(self.dataloader, item)

    def __len__(wrapped):
        # NOTE: the first parameter is deliberately named ``wrapped`` rather
        # than ``self`` so that ``self`` below resolves via closure to the
        # enclosing prepare_trainer's instance and its ``sp_world_size``.
        # Reported length is scaled up by the sequence-parallel world size —
        # presumably because each SP rank iterates only 1/sp_world_size of
        # the effective steps; TODO(review): confirm against the SP sampler.
        return len(wrapped.dataloader) * self.sp_world_size

    def __iter__(self):
        # Iteration is delegated unchanged; only __len__ is rescaled.
        yield from self.dataloader

def get_train_dataloader(trainer):
    """Replacement for ``trainer.get_train_dataloader``.

    Fetches the original dataloader via the saved
    ``get_origin_train_dataloader`` and wraps it in ``DataloaderWrap`` so
    that ``len()`` reflects the sequence-parallel scaling while iteration
    behavior is unchanged.
    """
    dataloader = trainer.get_origin_train_dataloader()
    return DataloaderWrap(dataloader)

trainer.get_origin_train_dataloader = trainer.get_train_dataloader
trainer.get_train_dataloader = MethodType(get_train_dataloader, trainer)

from swift.plugin import metric
from swift.trainers import mixin
compute_acc_origin = metric.compute_acc
Expand Down