Skip to content

Commit af58c7f

Browse files
committed
Allow n-tuples for CE MarginMSE training
1 parent 6e7d64e commit af58c7f

File tree

2 files changed

+51
-23
lines changed

2 files changed

+51
-23
lines changed

sentence_transformers/cross_encoder/losses/MarginMSELoss.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import torch
34
from torch import Tensor, nn
45

56
from sentence_transformers.cross_encoder.CrossEncoder import CrossEncoder
@@ -92,37 +93,65 @@ def compute_labels(batch):
9293
f"but got a model with {self.model.num_labels} output labels."
9394
)
9495

95-
def forward(self, inputs: list[list[str]], labels: Tensor) -> Tensor:
96-
if len(inputs) != 3:
96+
def forward(self, inputs: list[list[str]], labels: Tensor | list[Tensor]) -> Tensor:
97+
anchors = inputs[0]
98+
positives = inputs[1]
99+
negatives = inputs[2:]
100+
batch_size = len(anchors)
101+
102+
# If there are multiple scores, then `labels` is a list of tensors. We need to stack them into
103+
# a single tensor of shape (batch_size, num_columns - 1)
104+
if isinstance(labels, list):
105+
labels = torch.stack(labels, dim=1).T
106+
107+
if labels.shape == (batch_size, len(negatives) + 1):
108+
# If labels are given as a single score for positive and multiple negatives,
109+
# we need to adjust the labels to be the difference between positive and negatives
110+
labels = labels[:, 0].unsqueeze(1) - labels[:, 1:]
111+
112+
# Ensure the shape is (batch_size, num_negatives)
113+
if labels.shape == (batch_size,):
114+
labels = labels.unsqueeze(1)
115+
116+
if labels.shape != (batch_size, len(negatives)):
97117
raise ValueError(
98-
f"MSELoss expects a dataset with three non-label columns, but got a dataset with {len(inputs)} columns."
118+
f"Labels shape {labels.shape} does not match expected shape {(batch_size, len(negatives))}. "
119+
"Ensure that your dataset labels/scores are 1) lists of differences between positive scores and "
120+
"negatives scores (length `num_negatives`), or 2) lists of positive and negative scores "
121+
"(length `num_negatives + 1`)."
99122
)
100123

101-
positive_pairs = list(zip(inputs[0], inputs[1]))
102-
tokens = self.model.tokenizer(
103-
positive_pairs,
104-
padding=True,
105-
truncation=True,
106-
return_tensors="pt",
107-
)
108-
tokens.to(self.model.device)
109-
positive_logits = self.model(**tokens)[0].view(-1)
110-
positive_logits = self.activation_fn(positive_logits)
124+
positive_pairs = list(zip(anchors, positives))
125+
positive_logits = self.logits_from_pairs(positive_pairs)
126+
negative_logits_list = []
127+
for negative in negatives:
128+
negative_pairs = list(zip(anchors, negative))
129+
negative_logits_list.append(self.logits_from_pairs(negative_pairs))
111130

112-
negative_pairs = list(zip(inputs[0], inputs[2]))
131+
margin_logits = [positive_logits - negative_logits for negative_logits in negative_logits_list]
132+
margin_logits = torch.stack(margin_logits, dim=1)
133+
loss = self.loss_fct(margin_logits, labels.float())
134+
return loss
135+
136+
def logits_from_pairs(self, pairs: list[tuple[str, str]]) -> Tensor:
137+
"""
138+
Computes the logits for a list of pairs using the model.
139+
140+
Args:
141+
pairs (list[tuple[str, str]]): A list of pairs of strings (query, passage).
142+
143+
Returns:
144+
Tensor: The logits for the pairs.
145+
"""
113146
tokens = self.model.tokenizer(
114-
negative_pairs,
147+
pairs,
115148
padding=True,
116149
truncation=True,
117150
return_tensors="pt",
118151
)
119152
tokens.to(self.model.device)
120-
negative_logits = self.model(**tokens)[0].view(-1)
121-
negative_logits = self.activation_fn(negative_logits)
122-
123-
margin_logits = positive_logits - negative_logits
124-
loss = self.loss_fct(margin_logits, labels.float())
125-
return loss
153+
logits = self.model(**tokens)[0].view(-1)
154+
return self.activation_fn(logits)
126155

127156
def get_config_dict(self):
128157
return {

sentence_transformers/losses/MarginMSELoss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,8 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor)
187187
# we need to adjust the labels to be the difference between positive and negatives
188188
labels = labels[:, 0].unsqueeze(1) - labels[:, 1:]
189189

190+
# Ensure the shape is (batch_size, num_negatives)
190191
if labels.shape == (batch_size,):
191-
# If labels are given as a single score for positive and multiple negatives,
192-
# we need to adjust the labels to be the difference between positive and negatives
193192
labels = labels.unsqueeze(1)
194193

195194
if labels.shape != (batch_size, len(embeddings_negs)):

0 commit comments

Comments
 (0)