Skip to content

Commit 2edf3e1

Browse files
committed
add reduce option
1 parent 6c3d4c1 commit 2edf3e1

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

src/pytorch_metric_learning/losses/smooth_ap.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,20 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
8080
)
8181
sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1
8282

83-
ap = torch.zeros(1).to(embeddings.device)
8483
g = batch_size // num_classes_batch
84+
ap = torch.zeros(batch_size).to(embeddings.device)
8585
for i in range(num_classes_batch):
86-
pos_divide = torch.sum(
87-
sims_pos_ranks[i] / sims_ranks[i * g : (i + 1) * g, i * g : (i + 1) * g]
88-
)
89-
ap = ap + (pos_divide / g) / batch_size
86+
for j in range(g):
87+
pos_rank = sims_pos_ranks[i, j]
88+
all_rank = sims_ranks[i * g + j, i * g: (i + 1) * g]
89+
ap[i * g + j] = torch.sum(pos_rank / all_rank) / g
9090

9191
loss = 1 - ap
92+
9293
return {
93-
"loss": {
94+
"ap_loss": {
9495
"losses": loss,
95-
"indices": None,
96-
"reduction_type": "already_reduced",
96+
"indices": c_f.torch_arange_from_size(loss),
97+
"reduction_type": "element",
9798
}
98-
}
99+
}

tests/losses/test_smooth_ap_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_smooth_ap_loss(self):
187187
labels.extend([i for _ in range(HYPERPARAMETERS["num_id"])])
188188

189189
labels = torch.tensor(labels)
190-
output2 = loss2.compute_loss(
190+
output2 = loss2.forward(
191191
rand_tensor, labels, None, rand_tensor, labels
192-
)["loss"]["losses"]
192+
)
193193
self.assertTrue(torch.isclose(output, output2))

0 commit comments

Comments
 (0)