@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import torch
 from torch import Tensor, nn
 
 from sentence_transformers.cross_encoder.CrossEncoder import CrossEncoder
|
@@ -92,37 +93,65 @@ def compute_labels(batch):
                 f"but got a model with {self.model.num_labels} output labels."
             )
 
-    def forward(self, inputs: list[list[str]], labels: Tensor) -> Tensor:
-        if len(inputs) != 3:
+    def forward(self, inputs: list[list[str]], labels: Tensor | list[Tensor]) -> Tensor:
+        anchors = inputs[0]
+        positives = inputs[1]
+        negatives = inputs[2:]
+        batch_size = len(anchors)
+
+        # If there are multiple scores, `labels` is a list of tensors. We need to stack them into
+        # a single tensor of shape (batch_size, num_columns - 1)
+        if isinstance(labels, list):
+            labels = torch.stack(labels, dim=1).T
+
+        if labels.shape == (batch_size, len(negatives) + 1):
+            # If labels are given as raw scores (one for the positive and one per negative),
+            # convert them to differences between the positive score and each negative score
+            labels = labels[:, 0].unsqueeze(1) - labels[:, 1:]
+
+        # Ensure the shape is (batch_size, num_negatives)
+        if labels.shape == (batch_size,):
+            labels = labels.unsqueeze(1)
+
+        if labels.shape != (batch_size, len(negatives)):
             raise ValueError(
-                f"MSELoss expects a dataset with three non-label columns, but got a dataset with {len(inputs)} columns."
+                f"Labels shape {labels.shape} does not match expected shape {(batch_size, len(negatives))}. "
+                "Ensure that your dataset labels/scores are 1) lists of differences between positive scores and "
+                "negative scores (length `num_negatives`), or 2) lists of positive and negative scores "
+                "(length `num_negatives + 1`)."
             )
 
-        positive_pairs = list(zip(inputs[0], inputs[1]))
-        tokens = self.model.tokenizer(
-            positive_pairs,
-            padding=True,
-            truncation=True,
-            return_tensors="pt",
-        )
-        tokens.to(self.model.device)
-        positive_logits = self.model(**tokens)[0].view(-1)
-        positive_logits = self.activation_fn(positive_logits)
+        positive_pairs = list(zip(anchors, positives))
+        positive_logits = self.logits_from_pairs(positive_pairs)
+        negative_logits_list = []
+        for negative in negatives:
+            negative_pairs = list(zip(anchors, negative))
+            negative_logits_list.append(self.logits_from_pairs(negative_pairs))
 
-        negative_pairs = list(zip(inputs[0], inputs[2]))
+        margin_logits = [positive_logits - negative_logits for negative_logits in negative_logits_list]
+        margin_logits = torch.stack(margin_logits, dim=1)
+        loss = self.loss_fct(margin_logits, labels.float())
+        return loss
+
+    def logits_from_pairs(self, pairs: list[tuple[str, str]]) -> Tensor:
+        """
+        Computes the logits for a list of pairs using the model.
+
+        Args:
+            pairs (list[tuple[str, str]]): A list of pairs of strings (query, passage).
+
+        Returns:
+            Tensor: The logits for the pairs.
+        """
         tokens = self.model.tokenizer(
-            negative_pairs,
+            pairs,
             padding=True,
             truncation=True,
             return_tensors="pt",
         )
         tokens.to(self.model.device)
-        negative_logits = self.model(**tokens)[0].view(-1)
-        negative_logits = self.activation_fn(negative_logits)
-
-        margin_logits = positive_logits - negative_logits
-        loss = self.loss_fct(margin_logits, labels.float())
-        return loss
+        logits = self.model(**tokens)[0].view(-1)
+        return self.activation_fn(logits)
 
     def get_config_dict(self):
         return {
|
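A minimal end-to-end sketch, assuming this diff belongs to the CrossEncoder MarginMSELoss (the import path, model name, and dataset values here are assumptions for illustration, not part of the commit):

import torch

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.losses import MarginMSELoss

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")  # any single-label cross-encoder
loss = MarginMSELoss(model)

# Input columns: anchor, positive, then any number of negative columns.
anchors = ["how to bake bread"]
positives = ["Knead the dough, let it rise, then bake at 230C."]
negatives_1 = ["Bread is a staple food in many cultures."]
negatives_2 = ["Cakes are usually baked at lower temperatures."]
inputs = [anchors, positives, negatives_1, negatives_2]

# Labels as teacher margins (positive score minus each negative score),
# shape (batch_size, num_negatives); raw scores of shape
# (batch_size, num_negatives + 1) would also be accepted, per the diff above.
labels = torch.tensor([[0.7, 0.9]])

loss_value = loss(inputs, labels)  # MSE between predicted and teacher margins

In practice this loss would normally be passed to the trainer rather than called directly; calling it by hand as above simply exercises the forward pass shown in the diff.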