enhance: avoid unnecessary fallback in _bincount on CUDA with deterministic #3087

Draft: wants to merge 16 commits into master

2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

--
+- Fixed `_bincount` to avoid an unnecessary fallback on CUDA when deterministic algorithms are enabled ([#3087](https://github.com/Lightning-AI/torchmetrics/pull/3087))


---
14 changes: 8 additions & 6 deletions src/torchmetrics/utilities/data.py
@@ -20,7 +20,7 @@
from torch import Tensor

from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
-from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6
@@ -178,10 +178,12 @@ def _squeeze_if_scalar(data: Any) -> Any:
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """Implement custom bincount.

-    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running
-    MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of
-    `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption
-    as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``.
+    As of PyTorch v2.1, ``torch.bincount`` is supported in deterministic mode on CUDA
+    when no ``weights`` are provided and gradients are not required. However, this
+    operation remains unsupported or limited on some backends, such as MPS and XLA.
+    In those cases, we fall back to a manual implementation using `torch.arange` and `torch.eq`.
+    A small performance hit can be expected, along with higher memory consumption, as a
+    `[batch_size, minlength]` tensor needs to be initialized compared to native ``torch.bincount``.

    Args:
        x: tensor to count

@@ -199,7 +201,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    if minlength is None:
        minlength = len(torch.unique(x))

-    if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
+    if (not _TORCH_GREATER_EQUAL_2_1 and torch.are_deterministic_algorithms_enabled()) or _XLA_AVAILABLE or x.is_mps:
        mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

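For readers skimming the diff, here is a minimal sketch, not part of this PR, of what the `torch.arange`/`torch.eq` fallback computes and of the native path the new condition keeps on PyTorch >= 2.1. The helper name `manual_bincount` is hypothetical:

```python
import torch


def manual_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # Build a [len(x), minlength] grid of candidate bin indices, compare every
    # element of `x` against every bin, then sum the matches per bin.
    mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
    return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)


x = torch.tensor([0, 2, 2, 3])
assert torch.equal(manual_bincount(x, minlength=5), torch.bincount(x, minlength=5))

# Per the docstring above, on PyTorch >= 2.1 the native op should also run on
# CUDA in deterministic mode (no `weights`, no autograd), so no fallback is needed:
if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    counts = torch.bincount(x.cuda(), minlength=5)
    torch.use_deterministic_algorithms(False)
```

The fallback materializes a `[len(x), minlength]` boolean mesh, which is exactly the extra memory cost the docstring mentions; the native `torch.bincount` path avoids it.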