Hi! First of all, thank you so much for the great project!
Description
When training with `CrossBatchMemory` and `TripletMarginMiner`, we ran into the following error:
```
Traceback (most recent call last):
  ...
  File "pytorch-metric-learning/src/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 90, in get_all_triplets_indices
    return torch.where(triplets)
RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements, file a support request
```
It is raised when `labels` and `ref_labels` are on the GPU and `len(labels) * len(ref_labels) * len(ref_labels) > 2147483647`. This is because the `triplets` tensor has `len(labels) * len(ref_labels) * len(ref_labels)` elements, and `torch.where()` is essentially a `.nonzero()` call (docs).
As noted in the error message, this is a limitation of PyTorch.
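For context, the relevant part of the original implementation looks roughly like this (a paraphrased sketch based on the linked commit, not verbatim library code):

```python
import torch
from pytorch_metric_learning.utils.loss_and_miner_utils import get_matches_and_diffs


# Paraphrased sketch of get_all_triplets_indices
def get_all_triplets_indices_sketch(labels, ref_labels=None):
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    # Broadcasting creates a (len(labels), len(ref_labels), len(ref_labels)) tensor,
    # i.e. len(labels) * len(ref_labels) * len(ref_labels) elements.
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    # torch.where(condition) is equivalent to nonzero(as_tuple=True), and CUDA
    # nonzero() only supports tensors with at most INT_MAX (2**31 - 1) elements.
    return torch.where(triplets)
```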
Package versions
- Commit: https://github.com/KevinMusgrave/pytorch-metric-learning/tree/3a14f82f38af31c84c21867c865bbb902bfb1c35
- torch==2.1.1+cu118
Code to reproduce
I have written a script to test this issue and also reimplemented `get_all_triplets_indices` to work around the problem. I have also found that the new function is actually faster when `len(ref_labels) >> len(labels)`.
```python
from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_triplets_indices, get_matches_and_diffs
import torch
import timeit


def get_all_triplets_indices_new(labels, ref_labels=None):
    all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
    all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
    # Find anchors with at least a positive and a negative
    indices = torch.arange(0, len(labels), device=labels.device)
    indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]
    # No triplets found
    if len(indices) == 0:
        return (torch.tensor([], device=labels.device, dtype=labels.dtype),
                torch.tensor([], device=labels.device, dtype=labels.dtype),
                torch.tensor([], device=labels.device, dtype=labels.dtype))
    # Compute all triplets
    anchors = []
    positives = []
    negatives = []
    for i in indices:
        matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
        diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
        nd = len(diffs)
        nm = len(matches)
        # Cartesian product of positives and negatives for this anchor
        matches = matches.repeat_interleave(nd)
        diffs = diffs.repeat(nm)
        anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device))
        positives.append(matches)
        negatives.append(diffs)
    return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)


def main():
    BATCH_SIZE = 21  # with 22, the original function breaks
    MEMORY_SIZE = 10000
    NCLASSES = 1000
    a = torch.randint(0, NCLASSES, (BATCH_SIZE,))
    b = torch.randint(0, NCLASSES, (MEMORY_SIZE,))
    # b = None
    a = a.cuda()
    b = b.cuda()
    if BATCH_SIZE * MEMORY_SIZE * MEMORY_SIZE > torch.iinfo(torch.int32).max:
        print(f"Code will break ({BATCH_SIZE * MEMORY_SIZE * MEMORY_SIZE} > {torch.iinfo(torch.int32).max}).")
    y = get_all_triplets_indices_new(labels=a, ref_labels=b)
    x = get_all_triplets_indices(labels=a, ref_labels=b)
    assert (y[0] == x[0]).all() and (y[1] == x[1]).all() and (y[2] == x[2]).all()
    # Benchmark with timeit
    N = 50
    t = timeit.timeit(lambda: get_all_triplets_indices_new(labels=a, ref_labels=b), number=N)
    print(f"New function: {t / N * 1000:.2f} ms")
    t = timeit.timeit(lambda: get_all_triplets_indices(labels=a, ref_labels=b), number=N)
    print(f"Original function: {t / N * 1000:.2f} ms")


if __name__ == '__main__':
    main()
```
Some results
Improved performance
- Batch size = 21, Memory size = 10000: New function: 4.73 ms, Original function: 120.99 ms
- Batch size = 512, Memory size = 2000: New function: 92.61 ms, Original function: 116.38 ms
Worse performance
The new function could reduce performance when training with large batch sizes and no `CrossBatchMemory` (i.e., mining triplets within a single large batch). However, it might still be fast enough for most cases.
- Batch size = 1000, Memory size = 1000: New function: 130.77 ms, Original function: 56.52 ms
- Batch size = 256, Memory size = 256: New function: 9.77 ms, Original function: 0.98 ms
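
Given this tradeoff, a possible middle ground (a hypothetical wrapper I am sketching here, not something provided by the library) would be to keep the original vectorized path and fall back to the loop-based version only when the dense `triplets` tensor would exceed the `nonzero()` limit:

```python
import torch
from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_triplets_indices

# get_all_triplets_indices_new is the workaround defined in the script above.

INT32_MAX = torch.iinfo(torch.int32).max  # 2147483647


def get_all_triplets_indices_safe(labels, ref_labels=None):
    # Hypothetical dispatcher: use the fast vectorized implementation when the
    # intermediate triplets tensor fits under the CUDA nonzero() limit, and the
    # slower but memory-safe loop-based version otherwise.
    ref = labels if ref_labels is None else ref_labels
    if labels.is_cuda and len(labels) * len(ref) * len(ref) > INT32_MAX:
        return get_all_triplets_indices_new(labels, ref_labels)
    return get_all_triplets_indices(labels, ref_labels)
```

This would keep the fast path for in-batch mining while avoiding the RuntimeError for large `CrossBatchMemory` sizes.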