Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,6 @@ runs
*.pth

*zarr/*
issue38366/

issue8366/
68 changes: 67 additions & 1 deletion monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,73 @@
tqdm, _ = optional_import("tqdm", name="tqdm")
_nearest_mode = "nearest-exact"

__all__ = ["sliding_window_inference"]
__all__ = ["ensure_channel_first","sliding_window_inference"]

def ensure_channel_first(
    x: torch.Tensor,
    spatial_ndim: Optional[int] = None,
    channel_hint: Optional[int] = None,
    threshold: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Normalize a tensor to channel-first layout (N, C, spatial...).

    Args:
        x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
        spatial_ndim: Number of spatial dimensions. Accepted for API compatibility only;
            it does not influence the result.
        channel_hint: If provided, the expected channel size (e.g., num_classes). When exactly
            one of dim=1 / dim=-1 has this size, that dimension is taken as the channel dim.
        threshold: Heuristic upper bound for typical channel counts, used to disambiguate
            layouts when ``channel_hint`` is absent or inconclusive.

    Returns:
        A tuple (x_cf, orig_channel_dim):
            - x_cf: the tensor in channel-first layout. It is ``x`` itself when no
              reordering is needed, otherwise ``x.movedim(-1, 1)``.
            - orig_channel_dim: 1 if the input was treated as channel-first; -1 if the
              channel was taken from the last dimension.

    Raises:
        TypeError: if x is not a torch.Tensor.
        ValueError: if x.ndim < 3.

    Notes:
        1. When channel_hint is provided and matches exactly one candidate dim, it decides
           the layout.
        2. Otherwise (no hint, or the hint matches both or neither candidate), a heuristic
           is used: the last dim is treated as channels only when it is <= threshold while
           dim=1 is > threshold.
        3. In every other case, including ambiguous ones (both candidate dims small or both
           large), the input is returned unchanged and assumed channel-first, to avoid
           silent mis-reordering. No error is raised for ambiguity.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(x)}")
    if x.ndim < 3:
        raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")

    first = int(x.shape[1])  # candidate channel size at dim=1
    last = int(x.shape[-1])  # candidate channel size at dim=-1

    # 1) Strong signal: the hint decides only when it matches exactly one candidate.
    if channel_hint is not None:
        first_matches = first == channel_hint
        last_matches = last == channel_hint
        if first_matches and not last_matches:
            return x, 1
        if last_matches and not first_matches:
            return x.movedim(-1, 1), -1
        # both match or both mismatch: fall through to the size heuristic

    # 2) Heuristic: channels are usually small, so a small last dim next to a large
    #    dim=1 indicates channel-last.
    if last <= threshold < first:
        return x.movedim(-1, 1), -1

    # 3) Clearly channel-first, or ambiguous (both small / both large): keep the layout.
    return x, 1



def sliding_window_inference(
Expand Down
25 changes: 25 additions & 0 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.inferers.utils import ensure_channel_first

Comment on lines +18 to +19
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Validate import location and layering; avoid metrics → inferers dependency if possible

Importing ensure_channel_first from monai.inferers.utils couples metrics to inferers. Confirm final placement (inferers.utils vs metrics.utils as mentioned in the PR) and ensure no circulars. If this is a generic tensor-layout utility, consider relocating to a neutral utils module to reduce layering risk.

Run to verify single source and usages:


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -nP 'def\s+ensure_channel_first\b' -g 'monai/**/*.py'
rg -nP '^from\s+monai\.(?:inferers|metrics)\.utils\s+import\s+ensure_channel_first\b' -g 'monai/**/*.py' -S -C1

Length of output: 443


🏁 Script executed:

#!/bin/bash
set -euo pipefail
if [[ -f monai/metrics/utils.py ]]; then
  echo "metrics/utils.py exists"
else
  echo "metrics/utils.py not found"
fi
if [[ -f monai/utils.py ]]; then
  echo "monai/utils.py exists"
else
  echo "monai/utils.py not found"
fi

Length of output: 195


🏁 Script executed:

#!/bin/bash
set -euo pipefail
if [[ -d monai/utils ]]; then
  echo "monai/utils directory exists"
  ls monai/utils
else
  echo "monai/utils directory not found"
fi

Length of output: 318


Relocate ensure_channel_first to a neutral utils module

Avoid coupling metrics to inferers by keeping generic tensor-layout helpers out of monai.inferers. Since ensure_channel_first is a generic utility, move it to an existing or new module under monai/utils (e.g. monai/utils/misc.py or monai/utils/tensor_utils.py), then update imports accordingly. This reduces layering risk and prevents potential circular dependencies.

• File: monai/metrics/meandice.py (line 18)
• New location: monai/utils/(misc.py or tensor_utils.py)

Example diff:

-from monai.inferers.utils import ensure_channel_first
+from monai.utils.misc import ensure_channel_first

• After moving, remove the old definition in monai/inferers/utils.py or re-export it there if needed by inferers.
• Verify no import cycles between metrics and inferers after the change.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 18-19, the import from
monai.inferers.utils should be replaced because ensure_channel_first is a
generic tensor-layout helper; move its implementation from
monai/inferers/utils.py into a neutral utils module under monai/utils (e.g.,
monai/utils/misc.py or monai/utils/tensor_utils.py), update meandice.py to
import ensure_channel_first from the new module, and either remove the old
implementation from monai/inferers/utils.py or add a thin re-export there to
preserve backward compatibility; after changes, run import graph checks and
tests to ensure no import cycles between metrics and inferers.


from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
num_classes=self.num_classes,
)


def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Compute the dice value using ``DiceHelper``.
Expand Down Expand Up @@ -306,6 +309,28 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
"""


# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))

# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]

# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)

Comment on lines +314 to +333
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Fix NHWC/NCHW ambiguity; fail fast when num_classes is missing; strengthen fallback

As written, ambiguous shapes can slip through (e.g., NHWC where H==W==C>32); the heuristic may preserve NHWC and silently misinterpret channels. Add an explicit ambiguity check when num_classes is None, and only retry with a y-informed hint when plausible. Also error on the classic “y_pred is label map (C==1) but y is one-hot CL with C>1” unless num_classes is provided.

         # --- Normalize layout to channel-first (N, C, spatial...) ---
         # Prefer a strong signal when available.
-        if self.num_classes is not None:
-                y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
-        else:
-            # First pass: heuristic only.
-            y_pred, _ = ensure_channel_first(y_pred)
-            # Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
-            if y.ndim == y_pred.ndim:
-                plausible = {1, y.shape[1], y.shape[-1]}
-                if y_pred.shape[1] not in plausible:
-                    y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))
+        if self.num_classes is not None:
+            y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
+        else:
+            # Detect clearly ambiguous layouts (dim1 == dim-1 > 1) and require num_classes.
+            if y.ndim == y_pred.ndim and y.shape[1] == y.shape[-1] and y.shape[1] > 1:
+                raise ValueError(
+                    "Ambiguous channel layout: y has identical sizes at dim=1 and dim=-1; "
+                    "please provide `num_classes` to disambiguate."
+                )
+            # First pass: heuristic only.
+            y_pred, _ = ensure_channel_first(y_pred)
+            # Fallback: if implausible vs y's layout, retry with a hint from y (prefer last dim).
+            if y.ndim == y_pred.ndim:
+                plausible = {1, y.shape[1], y.shape[-1]}
+                if y_pred.shape[1] not in plausible:
+                    hint = int(y.shape[-1]) if y.shape[-1] in plausible else int(y.shape[1])
+                    y_pred, _ = ensure_channel_first(y_pred, channel_hint=hint)
 
         # Infer channels after normalization (or use provided).
         n_ch = self.num_classes or y_pred.shape[1]
 
         # Normalize y if it plausibly is channel-last.
-        if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
+        if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
             y, _ = ensure_channel_first(y, channel_hint=n_ch)
+        # Inconsistent pair: label-map y_pred (C==1) vs one-hot y (last dim > 1) without num_classes.
+        if (
+            self.num_classes is None
+            and y.ndim == y_pred.ndim
+            and y_pred.shape[1] == 1
+            and y.shape[-1] > 1
+            and y.shape[1] == 1
+        ):
+            raise ValueError(
+                "Inconsistent inputs: y_pred appears to be a label map (C==1) while y is one-hot "
+                "channel-last (last dim > 1) but `num_classes` is not provided. "
+                "Specify `num_classes` to compute per-class Dice."
+            )

Follow-ups:

  • Add tests for:
    • NHWC with H==W==C (e.g., 64) and num_classes=None → expect ValueError.
    • NHWC with large C and num_classes=C → matches NCHW result.
    • y_pred: [N,1,H,W], y: [N,H,W,C], num_classes=C → matches NCHW result; and same case with num_classes=None → expect ValueError.

I can draft the unit tests if you want them in tests/metrics/test_dice_layouts.py.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))
# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]
# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)
# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# Detect clearly ambiguous layouts (dim1 == dim-1 > 1) and require num_classes.
if y.ndim == y_pred.ndim and y.shape[1] == y.shape[-1] and y.shape[1] > 1:
raise ValueError(
"Ambiguous channel layout: y has identical sizes at dim=1 and dim=-1; "
"please provide `num_classes` to disambiguate."
)
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y (prefer last dim).
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
hint = int(y.shape[-1]) if y.shape[-1] in plausible else int(y.shape[1])
y_pred, _ = ensure_channel_first(y_pred, channel_hint=hint)
# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]
# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)
# Inconsistent pair: label-map y_pred (C==1) vs one-hot y (last dim > 1) without num_classes.
if (
self.num_classes is None
and y.ndim == y_pred.ndim
and y_pred.shape[1] == 1
and y.shape[-1] > 1
and y.shape[1] == 1
):
raise ValueError(
"Inconsistent inputs: y_pred appears to be a label map (C==1) while y is one-hot "
"channel-last (last dim > 1) but `num_classes` is not provided. "
"Specify `num_classes` to compute per-class Dice."
)
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 314 to 333, the current heuristic for
normalizing channel layout can silently accept ambiguous NHWC shapes (e.g.
H==W==C) and misinterpret channels when num_classes is None; update the logic to
detect and fail-fast on ambiguous layouts when num_classes is not provided, only
perform the y-informed retry when the hint is plausible (e.g., y.shape[-1] in
{1, y_pred.shape[-1], expected range}) and otherwise raise a ValueError
describing the ambiguity; additionally, add an explicit check that rejects the
case where y_pred is a label map (channels==1) while y is clearly one-hot
(channels>1) unless num_classes is supplied, and use num_classes to disambiguate
or convert confidently when available.

_apply_argmax, _threshold = self.apply_argmax, self.threshold
if self.num_classes is None:
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
Expand Down
Loading