
Allow scaling up the memory and batch size when using TripletMarginMiner and CrossBatchMemory #688

@mkmenta

Hi! First of all, thank you so much for the great project!

Description

When training with CrossBatchMemory and TripletMarginMiner, we ran into the following error:

Traceback (most recent call last):
  ...
 File "pytorch-metric-learning/src/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 90, in get_all_triplets_indices
    return torch.where(triplets)
RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements,   file a support request

It is raised when labels and ref_labels are on the GPU and len(labels) * len(ref_labels) * len(ref_labels) > 2147483647 (INT_MAX). For example, a batch of 22 against a memory of 10000 gives 22 * 10000 * 10000 = 2.2e9 elements, which is over the limit. This is because the triplets tensor has len(labels) * len(ref_labels) * len(ref_labels) elements and torch.where() is essentially a .nonzero() (docs).

As noted in the error message, this is a limitation of PyTorch itself.
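
For context, the relevant part of get_all_triplets_indices looks roughly like this (paraphrased from loss_and_miner_utils.py; exact details may differ):

def get_all_triplets_indices(labels, ref_labels=None):
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    # Shape (len(labels), len(ref_labels), len(ref_labels)): entry [a, p, n]
    # is nonzero iff p is a positive and n is a negative for anchor a
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    return torch.where(triplets)  # torch.where(cond) is cond.nonzero(as_tuple=True)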

Package versions

Code to reproduce

I have written a script to reproduce this issue and reimplemented get_all_triplets_indices to work around the problem. I also found that the new function is actually faster when len(ref_labels) >> len(labels).

from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_triplets_indices, get_matches_and_diffs
import torch
import timeit


def get_all_triplets_indices_new(labels, ref_labels=None):
    all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
    all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

    # Find anchors that have at least one positive and one negative
    indices = torch.arange(0, len(labels), device=labels.device)
    indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]

    # No triplets found: return empty index tensors
    # (indices should be torch.long, matching torch.where's output)
    if len(indices) == 0:
        empty = torch.tensor([], device=labels.device, dtype=torch.long)
        return empty, empty, empty

    # Compute all triplets
    anchors = []
    positives = []
    negatives = []
    for i in indices:
        matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
        diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
        nd = len(diffs)
        nm = len(matches)
        # Cartesian product of this anchor's positives and negatives
        matches = matches.repeat_interleave(nd)
        diffs = diffs.repeat(nm)
        anchors.append(torch.full((len(matches),), i, dtype=torch.long, device=labels.device))
        positives.append(matches)
        negatives.append(diffs)
    return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)


def main():
    BATCH_SIZE = 21  # with 22, the original function breaks (22 * 10000**2 > INT_MAX)
    MEMORY_SIZE = 10000
    NCLASSES = 1000
    a = torch.randint(0, NCLASSES, (BATCH_SIZE,))
    b = torch.randint(0, NCLASSES, (MEMORY_SIZE,))
    # b = None
    a = a.cuda()
    b = b.cuda()
    if BATCH_SIZE * MEMORY_SIZE * MEMORY_SIZE > torch.iinfo(torch.int32).max:
        print(f"Code will break ({BATCH_SIZE * MEMORY_SIZE * MEMORY_SIZE} > {torch.iinfo(torch.int32).max}).")
    y = get_all_triplets_indices_new(labels=a,
                                     ref_labels=b)
    x = get_all_triplets_indices(labels=a,
                                 ref_labels=b)
    assert (y[0] == x[0]).all() and (y[1] == x[1]).all() and (y[2] == x[2]).all()

    # Benchmark with timeit
    N = 50
    t = timeit.timeit(lambda: get_all_triplets_indices_new(labels=a,
                                                           ref_labels=b),
                      number=N)
    print(f"New function: {t/N*1000:.2f} ms")

    t = timeit.timeit(lambda: get_all_triplets_indices(labels=a,
                                                       ref_labels=b),
                      number=N)
    print(f"Original function: {t/N*1000:.2f} ms")


if __name__ == '__main__':
    main()

Some results

Improved performance

  • Batch size = 21, Memory size = 10000: New function: 4.73 ms, Original function: 120.99 ms
  • Batch size = 512, Memory size = 2000: New function: 92.61 ms, Original function: 116.38 ms

Worse performance

The new function could reduce performance in trainings with large batch sizes and no CrossBatchMemory (i.e., when mining triplets within a single large batch). However, it might still be fast enough for most cases; if not, the two implementations could be combined (see the sketch after the numbers below).

  • Batch size = 1000, Memory size = 1000: New function: 130.77 ms, Original function: 56.52 ms
  • Batch size = 256, Memory size = 256: New function: 9.77 ms, Original function: 0.98 ms
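
If the slowdown matters, one option (my suggestion, not something in the library) is a small dispatcher that keeps the original vectorized path whenever nonzero() cannot overflow and the memory is not much larger than the batch. The 4x cutoff below is an arbitrary assumption that would need tuning:

def get_all_triplets_indices_auto(labels, ref_labels=None):
    # Hypothetical helper combining the two implementations from the script above
    ref = labels if ref_labels is None else ref_labels
    numel = len(labels) * len(ref) * len(ref)
    fits = numel < torch.iinfo(torch.int32).max  # CUDA nonzero() limit
    if fits and len(ref) < 4 * len(labels):  # assumed cutoff, not benchmarked
        return get_all_triplets_indices(labels, ref_labels)
    return get_all_triplets_indices_new(labels, ref_labels)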
