Skip to content

Conversation

mattlin1124
Copy link

What does this PR do?
Adds ensure_channel_first utility in monai/metrics/utils.py and integrates it into DiceHelper.call in monai/metrics/meandice.py to normalize input tensors to channel-first layout before Dice computation. This prevents metric calculation errors when the input is provided in channel-last format.

Issue reference
Refs #8366

Motivation and context
Currently, DiceHelper assumes the channel dimension is at index 1. If a user passes channel-last tensors (e.g., [N, H, W, C]), it can lead to incorrect indexing and metric values. This PR ensures that y_pred and (when applicable) y are converted to channel-first format before further processing, preserving original metric behavior while adding channel-last support.

How I tested

Verified that channel-last inputs now produce identical results to channel-first inputs.

Ran pytest -q tests/metrics -k dice locally without failures.

Checked that no changes are introduced to the outputs when inputs are already channel-first.

Copy link

coderabbitai bot commented Aug 9, 2025

Walkthrough

.gitignore now ignores issue38366/ and issue8366/. Added ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int]=None, channel_hint: Optional[int]=None, threshold: int=32) -> Tuple[torch.Tensor, int] in monai/inferers/utils.py to canonicalize tensors to channel-first layout (N, C, spatial...) and return the original channel position (1 if already channel-first, -1 if channel was last). It infers spatial_ndim when omitted, accepts a channel_hint, uses a threshold heuristic to decide layout, reorders with movedim(-1, 1) when needed, and raises on invalid inputs. In monai/metrics/meandice.py a new DiceMetric._compute_tensor(self, y_pred, y) was added; DiceMetric.__call__ and DiceHelper.__call__ now normalize y_pred via ensure_channel_first and conditionally normalize y to channel-first when its trailing dim equals 1 or the number of channels. No public API signatures were removed or renamed.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between cafc1fe and 59786d2.

📒 Files selected for processing (2)
  • .gitignore (1 hunks)
  • monai/metrics/meandice.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit Configuration File

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/meandice.py
🔇 Additional comments (3)
.gitignore (1)

168-168: Ignore entry LGTM

Adding issue38366/ to .gitignore is harmless and scoped.

monai/metrics/meandice.py (2)

18-19: Import of ensure_channel_first is correct

Good placement and usage intent.


128-128: No-op whitespace change

Nothing to do here.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🔭 Outside diff range comments (1)
monai/inferers/utils.py (1)

16-16: Fix Ruff F821: import Optional and Tuple.

Type hints use Optional and Tuple but they aren’t imported.

Apply:

-from typing import Any
+from typing import Any, Optional, Tuple
🧹 Nitpick comments (2)
monai/inferers/utils.py (2)

39-39: Consider exporting ensure_channel_first via all.

If this is part of the public surface, add it to all for discoverability and star-import safety.

Add:

__all__ = ["sliding_window_inference", "ensure_channel_first"]

41-73: Add unit tests for layout normalization.

Cover NCHW/NHWC (2D), NCD/ND C (1D), ambiguity/error paths, and round-trip behavior.

Suggested cases:

  • (N, C=1/2/3/4, H, W) → unchanged; orig=1.
  • (N, H, W, C=1/2/3/4) → movedim to (N, C, H, W); orig=-1.
  • (N, C, L) and (N, L, C) → both paths.
  • ndim < 3 → ValueError.
  • Ambiguous both-large dims → ValueError.
    I can draft tests in monai/tests/test_ensure_channel_first.py if helpful.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 59786d2 and c2b5fc4.

📒 Files selected for processing (2)
  • monai/inferers/utils.py (1 hunks)
  • monai/metrics/meandice.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • monai/metrics/meandice.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit Configuration File

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/inferers/utils.py
🪛 Ruff (0.12.2)
monai/inferers/utils.py

41-41: Undefined name Optional

(F821)


41-41: Undefined name Tuple

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)

@@ -38,7 +38,39 @@

__all__ = ["sliding_window_inference"]

def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify import path used by DiceHelper.

PR text says the function was added to monai/metrics/utils.py, but it’s implemented here (monai/inferers/utils.py). Confirm call sites import the correct module path, or move the helper to the intended module.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail
echo "Usages of ensure_channel_first:"
rg -n --hidden -S '\bensure_channel_first\b' || true
echo
echo "Imports of ensure_channel_first:"
rg -n --hidden -S 'from monai\.(metrics|inferers)\.utils import ensure_channel_first' || true
echo
echo "Definition locations:"
ast-grep --pattern $'def ensure_channel_first($_):\n  $$$'

Length of output: 3584


Ensure correct import path for ensure_channel_first
The helper lives in monai/inferers/utils.py, but in monai/metrics/meandice.py (line 18) it’s imported from monai.metrics.utils, which doesn’t define it.

• Either move ensure_channel_first into monai/metrics/utils.py and keep the existing import
• Or update the import in monai/metrics/meandice.py to:
from monai.inferers.utils import ensure_channel_first

🧰 Tools
🪛 Ruff (0.12.2)

41-41: Undefined name Optional

(F821)


41-41: Undefined name Tuple

(F821)

🤖 Prompt for AI Agents
In monai/metrics/meandice.py at line 18, the import statement for
ensure_channel_first is incorrect because it imports from monai.metrics.utils
where the function is not defined. To fix this, update the import statement to
import ensure_channel_first from monai.inferers.utils instead, i.e., change it
to 'from monai.inferers.utils import ensure_channel_first'. Alternatively, if
preferred, move the ensure_channel_first function from monai/inferers/utils.py
to monai/metrics/utils.py and keep the existing import unchanged.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (1)
monai/inferers/utils.py (1)

41-41: Fix Ruff F821: use PEP 604/PEP 585 types (no need to import Optional/Tuple).

Ruff flags undefined names Optional/Tuple. Prefer modern annotations to avoid adding imports.

Apply:

-def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
+def ensure_channel_first(x: torch.Tensor, spatial_ndim: int | None = None) -> tuple[torch.Tensor, int]:
🧹 Nitpick comments (5)
monai/inferers/utils.py (3)

58-61: Docstring tweak: align “ambiguous” behavior with implementation.

After the fix above, ambiguous (both small or both large) always preserves layout. Update the Notes to reflect that we preserve rather than raise.

Proposed wording:

  • “Uses a small-channel heuristic (<=32). In ambiguous cases (both candidate dims small or both large), preserves input layout to avoid silent mis-reordering.”

67-70: Minor: clarify intentionally unused spatial_ndim.

The inferred spatial_ndim isn’t used. Either drop the block or rename the local to underscore to signal intent.

Apply:

-    if spatial_ndim is None:
-        spatial_ndim = x.ndim - 2  # not directly used for logic; informative only
+    if spatial_ndim is None:
+        _spatial_ndim = x.ndim - 2  # informational only; not used in logic

39-39: Export ensure_channel_first in all for public use.

You’re using this utility across modules; consider exposing it via all to mark it public.

Outside this hunk, update:

__all__ = ["sliding_window_inference", "ensure_channel_first"]
monai/metrics/meandice.py (2)

128-145: _compute_tensor addition LGTM; tiny style nit.

Validation and delegation are correct. Prefer y_pred.ndim over ndimension().

Apply:

-        dims = y_pred.ndimension()
+        dims = y_pred.ndim

304-311: Docstring: explicitly state channel-last is accepted and auto-normalized.

Reflect the new behavior for clarity.

Apply:

-        Args:
-            y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
-                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
-            y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
+        Args:
+            y_pred: predictions in channel-first or channel-last layout:
+                - (batch_size, num_classes or 1, spatial_dims...)  [channel-first]
+                - (batch_size, spatial_dims..., num_classes or 1)  [channel-last]
+                The number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes`` is None.
+            y: ground truth in channel-first or channel-last layout; if trailing dim is 1 or matches
+                the number of classes, it will be normalized to channel-first automatically.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c2b5fc4 and 77d09cf.

📒 Files selected for processing (2)
  • monai/inferers/utils.py (1 hunks)
  • monai/metrics/meandice.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit Configuration File

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/meandice.py
  • monai/inferers/utils.py
🧬 Code Graph Analysis (1)
monai/metrics/meandice.py (1)
monai/inferers/utils.py (1)
  • ensure_channel_first (41-91)
🪛 Ruff (0.12.2)
monai/inferers/utils.py

41-41: Undefined name Optional

(F821)


41-41: Undefined name Tuple

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: packaging
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (1)
monai/metrics/meandice.py (1)

18-19: Correct import path; good catch.

Importing ensure_channel_first from monai.inferers.utils matches the implementation location.

Comment on lines 84 to 107
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if s1 <= threshold and sl <= threshold:
return x, 1

raise ValueError(
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
f"both dim1={s1} and dim-1={sl} look like spatial dims"
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bug: both-large case wrongly raises; breaks valid NCHW with many classes (C > 32).

For y_pred in channel-first with many classes (e.g., N, 64, H, W), s1 > 32 and sl (W) > 32; this path raises ValueError and will break Dice. Prefer preserving layout in both-large cases (consistent with the comment on Line 84).

Apply:

-    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
-    if s1 <= threshold and sl <= threshold:
-        return x, 1
-
-    raise ValueError(
-        f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
-        f"both dim1={s1} and dim-1={sl} look like spatial dims"
-    )
+    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
+    if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
+        return x, 1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if s1 <= threshold and sl <= threshold:
return x, 1
raise ValueError(
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
f"both dim1={s1} and dim-1={sl} look like spatial dims"
)
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
return x, 1
🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 84 to 91, the code currently raises a
ValueError when both dim1 and dim-1 are large (e.g., C>32), which incorrectly
rejects valid channel-first tensors; change the branching so that when both s1
and sl are small OR both are large the function preserves the existing layout
(return x, 1) instead of raising; only raise when sizes are truly ambiguous
(e.g., one small and one small-ish conflict) or adjust the thresholds/conditions
accordingly and update the error message to reflect remaining ambiguous cases.

Comment on lines 312 to 322
y_pred, _ = ensure_channel_first(y_pred)

n_ch = self.num_classes or y_pred.shape[1]
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Channel-last normalization is correctly integrated; ensure upstream heuristic doesn’t blow up.

This is the right place to normalize y_pred and y. However, with current ensure_channel_first, channel-first tensors with many classes (C > 32) will raise unless the utils fix is applied. Once that’s in, this path is solid.

To be safe, verify both channel-first and channel-last with larger class counts:

Expected: both print tensor shapes without exceptions.


🏁 Script executed:

#!/bin/bash
python - <<'PY'
import torch
from monai.metrics.meandice import DiceHelper

def run(n=2, c=64, h=64, w=64):
    y_pred_cf = torch.rand(n, c, h, w)
    y_cf = torch.randint(0, c, (n, 1, h, w))
    print("CF/C>32:", DiceHelper(num_classes=c)(y_pred_cf, y_cf).shape)

    y_pred_cl = torch.rand(n, h, w, c)
    y_cl = torch.randint(0, c, (n, h, w, 1))
    print("CL/C>32:", DiceHelper(num_classes=c)(y_pred_cl, y_cl).shape)

run()
PY

Length of output: 128


🏁 Script executed:

#!/bin/bash
set -eo pipefail

# Locate ensure_channel_first definition
rg -n "def ensure_channel_first" -g "*.py"

Length of output: 197


🏁 Script executed:

#!/bin/bash
set -eo pipefail

# Show ensure_channel_first implementation
sed -n '1,200p' monai/inferers/utils.py

Length of output: 10537


Support >32 channels in ensure_channel_first
ensure_channel_first in monai/inferers/utils.py hard-codes a 32-channel threshold and will throw on both NCHW and NHWC inputs when C>32. Before merging, please:

  • Update ensure_channel_first to accept a configurable threshold (e.g., use num_classes or expose a kwarg) so it no longer raises for large class counts.
  • Or catch its ValueError in DiceHelper and default to treating dim-1 as the channel axis.
  • Add unit tests for C>32 on both channel-first and channel-last inputs to verify no exceptions.
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 312 to 317, ensure_channel_first can
raise for channel counts >32; modify the code to catch ValueError from
ensure_channel_first and on exception treat dim-1 as the channel axis (i.e.,
transpose NHWC to NCHW manually or accept input as already channel-first),
ensuring n_ch is computed from num_classes or y_pred.shape[1] accordingly;
alternatively, update monai/inferers/utils.py:ensure_channel_first to accept a
configurable threshold or a kwarg (e.g., max_channel_threshold or
force_last_dim_channel) and use that in this call so it won’t raise for large C;
add unit tests that cover C>32 for both channel-first (NCHW) and channel-last
(NHWC) inputs to verify no exceptions and correct channel handling.

林旻佑 and others added 12 commits August 24, 2025 16:14
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
monai/inferers/utils.py (1)

86-103: Ambiguous both-large case needs a call-site hint; document the contract.

When both dim1 and dim-1 are > threshold and no channel_hint is given, NHWC with large C will be preserved as-is (assumed NCHW). That’s fine if call sites always supply channel_hint when num_classes is known. Please state this expectation explicitly in the docstring (Notes) to avoid silent mis-ordering in downstream uses that don’t pass a hint.

monai/metrics/meandice.py (1)

312-321: Optional: early ambiguity check for y.

If y has neither y.shape[1] nor y.shape[-1] in {1, n_ch}, consider raising a clear error to surface layout issues sooner.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 77d09cf and 51eeb65.

📒 Files selected for processing (3)
  • .gitignore (1 hunks)
  • monai/inferers/utils.py (1 hunks)
  • monai/metrics/meandice.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • .gitignore
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/meandice.py
  • monai/inferers/utils.py
🪛 Ruff (0.12.2)
monai/inferers/utils.py

41-41: SyntaxError: Simple statements must be separated by newlines or semicolons


41-41: SyntaxError: Simple statements must be separated by newlines or semicolons

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: packaging
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
🔇 Additional comments (4)
monai/inferers/utils.py (1)

44-74: API and docstring look solid.

Signature, types, and Google-style docstring meet guidelines. Nice.

monai/metrics/meandice.py (3)

18-19: Correct import path.

Importing from monai.inferers.utils matches the actual implementation location. Good.


128-145: _compute_tensor addition LGTM.

Clear validation and delegation to DiceHelper; aligns with CumulativeIterationMetric design.


18-19: Import check passed; manual NHWC/NCHW validation required

  • Only ensure_channel_first import found at monai/metrics/meandice.py:18
  • No stale imports from monai.metrics.utils detected
  • Sandbox lacks PyTorch—please re-run the NHWC/NCHW smoke-test snippet locally (also for the block at lines 312–321)

Comment on lines 41 to 43
dfrom typing import Optional, Tuple
import torch

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix syntax error and remove duplicate import; also export the new symbol.

  • Typo: dfromfrom causes a SyntaxError.
  • import torch is duplicated (already imported on Line 19).
  • Consider exporting ensure_channel_first via __all__.

Apply:

- dfrom typing import Optional, Tuple
-import torch
+from typing import Optional, Tuple

And update the export list (outside the changed hunk, around Line 39):

-__all__ = ["sliding_window_inference"]
+__all__ = ["ensure_channel_first", "sliding_window_inference"]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dfrom typing import Optional, Tuple
import torch
# Around line 39: update exports
__all__ = ["ensure_channel_first", "sliding_window_inference"]
# Around lines 41–43: fix typo and drop duplicate import
from typing import Optional, Tuple
🧰 Tools
🪛 Ruff (0.12.2)

41-41: SyntaxError: Simple statements must be separated by newlines or semicolons


41-41: SyntaxError: Simple statements must be separated by newlines or semicolons

🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 41 to 43, fix the syntax error by
changing the typo "dfrom" to "from", remove the duplicate "import torch" since
torch is already imported earlier (line 19), and update the module export list
(__all__ around line 39) to include the new symbol "ensure_channel_first" so it
is exported.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (5)
monai/inferers/utils.py (5)

39-39: Export the new symbol.

Expose ensure_channel_first via __all__ so it’s discoverable and consistent with other utilities.

-__all__ = ["sliding_window_inference"]
+__all__ = ["ensure_channel_first", "sliding_window_inference"]

47-71: Clarify docstring: “orig_channel_dim” is an inference, not guaranteed original.

In ambiguous cases you return (x, 1) even if the input was channel-last. Make the doc explicit to avoid misuse.

-        - orig_channel_dim: 1 if input was already channel-first; -1 if the channel was last.
+        - orig_channel_dim: 1 when the function infers/preserves channel-first; -1 if the channel was last
+          and a reorder was performed. In ambiguous cases, the layout is preserved and this value will be 1.

Optionally add:

     Notes:
-        1. When channel_hint is provided, it is used as a strong signal to decide layout.
+        1. When channel_hint is provided, it is used as a strong signal to decide layout. Supplying it
+           is recommended to resolve small-volume ambiguities.

77-79: Drop unused assignment or use the value.

spatial_ndim is computed then unused. Either remove the block or incorporate it into error messages. Keeping dead code invites confusion.

-    if spatial_ndim is None:
-        spatial_ndim = x.ndim - 2  # informative only
+    # spatial_ndim is optional; current logic does not require it.

83-90: Channel hint tie-breaker policy (optional).

When s1 == sl == channel_hint, you fall back to the size heuristic. Consider explicitly preserving layout in this tie to make the policy obvious.

-    if channel_hint is not None:
-        if s1 == channel_hint and sl != channel_hint:
+    if channel_hint is not None:
+        if s1 == channel_hint and sl != channel_hint:
             return x, 1
         if sl == channel_hint and s1 != channel_hint:
             return x.movedim(-1, 1), -1
-        # if both match or both mismatch, fall back to heuristic
+        # both match or both mismatch: ambiguous -> preserve
+        # (continue to heuristic if you prefer size-based decision instead)

91-100: ensure_channel_first calls include channel_hint; add tests for ambiguous layouts

All existing invocations of ensure_channel_first (e.g. monai/metrics/meandice.py lines 316 and 320) explicitly pass channel_hint, so no call sites are missing the hint.
However, there are currently no unit tests covering the heuristic’s “ambiguous” case (both spatial dims ≤ threshold or both > threshold), which preserves the input layout by default.

• Recommend adding parametric tests for scenarios such as:

  • NHWC small volumes (e.g. (2,16,16,16,2))
  • NCHW with many channels (e.g. (2,64,32,32))
  • Clear NHWC (e.g. (2,256,256,3))
  • Clear NCHW (e.g. (2,3,256,256))

Happy to help draft these tests.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 51eeb65 and 12a34b7.

📒 Files selected for processing (1)
  • monai/inferers/utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/inferers/utils.py
🪛 Ruff (0.12.2)
monai/inferers/utils.py

43-43: Undefined name Optional

(F821)


44-44: Undefined name Optional

(F821)


46-46: Undefined name Tuple

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)

Comment on lines 41 to 46
def ensure_channel_first(
x: torch.Tensor,
spatial_ndim: Optional[int] = None,
channel_hint: Optional[int] = None,
threshold: int = 32,
) -> Tuple[torch.Tensor, int]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix undefined typing names; use PEP 604 unions for consistency.

Optional/Tuple aren’t imported (Ruff F821) and the rest of the module already uses | and builtin tuple[...]. Update the signature to avoid the lint error and match style.

-def ensure_channel_first(
-    x: torch.Tensor,
-    spatial_ndim: Optional[int] = None,
-    channel_hint: Optional[int] = None,
-    threshold: int = 32,
-) -> Tuple[torch.Tensor, int]:
+def ensure_channel_first(
+    x: torch.Tensor,
+    spatial_ndim: int | None = None,
+    channel_hint: int | None = None,
+    threshold: int = 32,
+) -> tuple[torch.Tensor, int]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def ensure_channel_first(
x: torch.Tensor,
spatial_ndim: Optional[int] = None,
channel_hint: Optional[int] = None,
threshold: int = 32,
) -> Tuple[torch.Tensor, int]:
def ensure_channel_first(
x: torch.Tensor,
spatial_ndim: int | None = None,
channel_hint: int | None = None,
threshold: int = 32,
) -> tuple[torch.Tensor, int]:
🧰 Tools
🪛 Ruff (0.12.2)

43-43: Undefined name Optional

(F821)


44-44: Undefined name Optional

(F821)


46-46: Undefined name Tuple

(F821)

🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 41 to 46, the function signature uses
undefined typing names Optional and Tuple and should be converted to PEP 604
style to match the module; change the signature to use spatial_ndim: int | None
and channel_hint: int | None and return type tuple[torch.Tensor, int] (keeping
torch.Tensor for the first arg), and remove any now-unused imports of Optional
and Tuple from typing (or replace them where needed elsewhere) so the lint error
is resolved and style is consistent.

@mattlin1124 mattlin1124 reopened this Aug 24, 2025
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/inferers/utils.py (1)

41-46: Fix undefined typing names (Ruff F821).

Optional is not imported; switch to PEP 604 unions to match the module style and avoid the lint error.

Apply:

-def ensure_channel_first(
-    x: torch.Tensor,
-    spatial_ndim: Optional[int] = None,
-    channel_hint: Optional[int] = None,
-    threshold: int = 32,
-) -> tuple[torch.Tensor, int]:
+def ensure_channel_first(
+    x: torch.Tensor,
+    spatial_ndim: int | None = None,
+    channel_hint: int | None = None,
+    threshold: int = 32,
+) -> tuple[torch.Tensor, int]:
🧹 Nitpick comments (3)
monai/inferers/utils.py (2)

77-82: Either use spatial_ndim for validation or drop it.

Currently computed but not used. Consider enforcing consistency when provided.

Apply:

-    if spatial_ndim is None:
-        spatial_ndim = x.ndim - 2  # informative only
+    if spatial_ndim is None:
+        spatial_ndim = x.ndim - 2
+    elif x.ndim != spatial_ndim + 2:
+        raise ValueError(
+            f"Expected x.ndim == spatial_ndim + 2, got x.ndim={x.ndim}, spatial_ndim={spatial_ndim}."
+        )

83-105: Ambiguous layouts can silently mis-orient NHWC with large C; expose an explicit policy.

When both candidate dims are “small” or “large”, you preserve layout. That’s safe but can silently keep NHWC for C>threshold. Offer a controllable policy to let callers fail fast or prefer last.

Apply:

-def ensure_channel_first(
+def ensure_channel_first(
     x: torch.Tensor,
-    spatial_ndim: int | None = None,
-    channel_hint: int | None = None,
-    threshold: int = 32,
+    spatial_ndim: int | None = None,
+    channel_hint: int | None = None,
+    threshold: int = 32,
+    on_ambiguous: str = "preserve",  # {"preserve","raise","prefer_last"}
 ) -> tuple[torch.Tensor, int]:
@@
-    # 3) Ambiguous: both sides small OR both sides large → preserve as channel-first
-    if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
-        return x, 1
+    # 3) Ambiguous: both sides small OR both sides large
+    if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
+        if on_ambiguous == "preserve":
+            return x, 1
+        if on_ambiguous == "raise":
+            raise ValueError(
+                f"Ambiguous channel dim for shape={tuple(x.shape)}; set `num_classes` or pass `channel_hint`."
+            )
+        if on_ambiguous == "prefer_last":
+            return x.movedim(-1, 1), -1
+        raise ValueError(f"Invalid on_ambiguous={on_ambiguous!r}")

Follow-up: meandice can pass on_ambiguous="raise" when num_classes is None to avoid silent mis-ordering.

monai/metrics/meandice.py (1)

324-327: Normalize y: add explicit guard for ambiguous y layout.

Current condition handles common last-dim cases. Consider failing fast when neither dim looks like a channel size to avoid silent misuse.

Apply:

         # Normalize y if it plausibly is channel-last.
         if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
             y, _ = ensure_channel_first(y, channel_hint=n_ch)
+        elif y.ndim == y_pred.ndim and y.shape[1] not in (1, n_ch) and y.shape[-1] not in (1, n_ch):
+            raise ValueError(
+                f"Ambiguous y layout for shape={tuple(y.shape)}; expected channel dim size in {{1, {n_ch}}}."
+            )

Optional: Update DiceHelper/DiceMetric docstrings to state NHWC inputs are accepted and auto-normalized.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 12a34b7 and 0fea070.

📒 Files selected for processing (2)
  • monai/inferers/utils.py (1 hunks)
  • monai/metrics/meandice.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/meandice.py
  • monai/inferers/utils.py
🪛 Ruff (0.12.2)
monai/inferers/utils.py

43-43: Undefined name Optional

(F821)


44-44: Undefined name Optional

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
🔇 Additional comments (4)
monai/inferers/utils.py (3)

39-40: Good: export the new utility.

Adding ensure_channel_first to all is correct and keeps the symbol discoverable.


47-71: Docstring is clear and aligned with guidelines.

Args/Returns/Raises are explicit; notes are helpful. No action needed.


72-76: Input validation is fine.

Type and rank checks are appropriate for early failure.

monai/metrics/meandice.py (1)

18-19: Correct import path.

Importing ensure_channel_first from monai.inferers.utils matches its actual location.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
monai/metrics/meandice.py (2)

129-145: Add symmetric validation for y dims to fail fast on malformed inputs

You validate y_pred.ndim >= 3 but not y. Add the same check for y to surface bad inputs early (pre-loop).

   def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
@@
-        dims = y_pred.ndimension()
-        if dims < 3:
-            raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
+        pred_dims = y_pred.ndimension()
+        if pred_dims < 3:
+            raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {pred_dims}.")
+        y_dims = y.ndimension()
+        if y_dims < 3:
+            raise ValueError(f"y should have at least 3 dimensions (batch, channel, spatial), got {y_dims}.")

307-311: Doc: call now accepts channel-last; clarify ambiguity behavior

The docstring still reads as channel-first only. Note that (N, spatial..., C) is accepted and normalized; also state that ambiguous layouts (e.g., dim1 == dim-1 > 1) require num_classes.

         """
-        Compute the metric for the given prediction and ground truth.
+        Compute the metric for the given prediction and ground truth.
@@
-            y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
-                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
-            y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
+            y_pred: input predictions with shape either channel-first
+                (batch_size, num_classes or 1, spatial_dims...) or channel-last
+                (batch_size, spatial_dims..., num_classes or 1). Channel-last inputs are
+                auto-normalized to channel-first internally.
+                If ``num_classes`` is ``None``, the number of channels is inferred from
+                the normalized ``y_pred.shape[1]``. For ambiguous layouts where the channel
+                dimension cannot be inferred (e.g., ``y.shape[1] == y.shape[-1] > 1``),
+                provide ``num_classes`` to disambiguate.
+            y: ground truth with shape matching ``y_pred`` in either channel-first or
+                channel-last format; it will be normalized as needed.
         """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0fea070 and c52922c.

📒 Files selected for processing (1)
  • monai/metrics/meandice.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/meandice.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.6.0)

Comment on lines +18 to +19
from monai.inferers.utils import ensure_channel_first

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Validate import location and layering; avoid metrics → inferers dependency if possible

Importing ensure_channel_first from monai.inferers.utils couples metrics to inferers. Confirm final placement (inferers.utils vs metrics.utils as mentioned in the PR) and ensure no circulars. If this is a generic tensor-layout utility, consider relocating to a neutral utils module to reduce layering risk.

Run to verify single source and usages:


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -nP 'def\s+ensure_channel_first\b' -g 'monai/**/*.py'
rg -nP '^from\s+monai\.(?:inferers|metrics)\.utils\s+import\s+ensure_channel_first\b' -g 'monai/**/*.py' -S -C1

Length of output: 443


🏁 Script executed:

#!/bin/bash
set -euo pipefail
if [[ -f monai/metrics/utils.py ]]; then
  echo "metrics/utils.py exists"
else
  echo "metrics/utils.py not found"
fi
if [[ -f monai/utils.py ]]; then
  echo "monai/utils.py exists"
else
  echo "monai/utils.py not found"
fi

Length of output: 195


🏁 Script executed:

#!/bin/bash
set -euo pipefail
if [[ -d monai/utils ]]; then
  echo "monai/utils directory exists"
  ls monai/utils
else
  echo "monai/utils directory not found"
fi

Length of output: 318


Relocate ensure_channel_first to a neutral utils module

Avoid coupling metrics to inferers by keeping generic tensor-layout helpers out of monai.inferers. Since ensure_channel_first is a generic utility, move it to an existing or new module under monai/utils (e.g. monai/utils/misc.py or monai/utils/tensor_utils.py), then update imports accordingly. This reduces layering risk and prevents potential circular dependencies.

• File: monai/metrics/meandice.py (line 18)
• New location: monai/utils/(misc.py or tensor_utils.py)

Example diff:

-from monai.inferers.utils import ensure_channel_first
+from monai.utils.misc import ensure_channel_first

• After moving, remove the old definition in monai/inferers/utils.py or re-export it there if needed by inferers.
• Verify no import cycles between metrics and inferers after the change.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 18-19, the import from
monai.inferers.utils should be replaced because ensure_channel_first is a
generic tensor-layout helper; move its implementation from
monai/inferers/utils.py into a neutral utils module under monai/utils (e.g.,
monai/utils/misc.py or monai/utils/tensor_utils.py), update meandice.py to
import ensure_channel_first from the new module, and either remove the old
implementation from monai/inferers/utils.py or add a thin re-export there to
preserve backward compatibility; after changes, run import graph checks and
tests to ensure no import cycles between metrics and inferers.

Comment on lines +314 to +333
# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))

# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]

# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Fix NHWC/NCHW ambiguity; fail fast when num_classes is missing; strengthen fallback

As written, ambiguous shapes can slip through (e.g., NHWC where H==W==C>32); the heuristic may preserve NHWC and silently misinterpret channels. Add an explicit ambiguity check when num_classes is None, and only retry with a y-informed hint when plausible. Also error on the classic “y_pred is label map (C==1) but y is one-hot CL with C>1” unless num_classes is provided.

         # --- Normalize layout to channel-first (N, C, spatial...) ---
         # Prefer a strong signal when available.
-        if self.num_classes is not None:
-                y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
-        else:
-            # First pass: heuristic only.
-            y_pred, _ = ensure_channel_first(y_pred)
-            # Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
-            if y.ndim == y_pred.ndim:
-                plausible = {1, y.shape[1], y.shape[-1]}
-                if y_pred.shape[1] not in plausible:
-                    y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))
+        if self.num_classes is not None:
+            y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
+        else:
+            # Detect clearly ambiguous layouts (dim1 == dim-1 > 1) and require num_classes.
+            if y.ndim == y_pred.ndim and y.shape[1] == y.shape[-1] and y.shape[1] > 1:
+                raise ValueError(
+                    "Ambiguous channel layout: y has identical sizes at dim=1 and dim=-1; "
+                    "please provide `num_classes` to disambiguate."
+                )
+            # First pass: heuristic only.
+            y_pred, _ = ensure_channel_first(y_pred)
+            # Fallback: if implausible vs y's layout, retry with a hint from y (prefer last dim).
+            if y.ndim == y_pred.ndim:
+                plausible = {1, y.shape[1], y.shape[-1]}
+                if y_pred.shape[1] not in plausible:
+                    hint = int(y.shape[-1]) if y.shape[-1] in plausible else int(y.shape[1])
+                    y_pred, _ = ensure_channel_first(y_pred, channel_hint=hint)
 
         # Infer channels after normalization (or use provided).
         n_ch = self.num_classes or y_pred.shape[1]
 
         # Normalize y if it plausibly is channel-last.
-        if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
+        if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
             y, _ = ensure_channel_first(y, channel_hint=n_ch)
+        # Inconsistent pair: label-map y_pred (C==1) vs one-hot y (last dim > 1) without num_classes.
+        if (
+            self.num_classes is None
+            and y.ndim == y_pred.ndim
+            and y_pred.shape[1] == 1
+            and y.shape[-1] > 1
+            and y.shape[1] == 1
+        ):
+            raise ValueError(
+                "Inconsistent inputs: y_pred appears to be a label map (C==1) while y is one-hot "
+                "channel-last (last dim > 1) but `num_classes` is not provided. "
+                "Specify `num_classes` to compute per-class Dice."
+            )

Follow-ups:

  • Add tests for:
    • NHWC with H==W==C (e.g., 64) and num_classes=None → expect ValueError.
    • NHWC with large C and num_classes=C → matches NCHW result.
    • y_pred: [N,1,H,W], y: [N,H,W,C], num_classes=C → matches NCHW result; and same case with num_classes=None → expect ValueError.

I can draft the unit tests if you want them in tests/metrics/test_dice_layouts.py.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))
# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]
# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)
# --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
if self.num_classes is not None:
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
else:
# Detect clearly ambiguous layouts (dim1 == dim-1 > 1) and require num_classes.
if y.ndim == y_pred.ndim and y.shape[1] == y.shape[-1] and y.shape[1] > 1:
raise ValueError(
"Ambiguous channel layout: y has identical sizes at dim=1 and dim=-1; "
"please provide `num_classes` to disambiguate."
)
# First pass: heuristic only.
y_pred, _ = ensure_channel_first(y_pred)
# Fallback: if implausible vs y's layout, retry with a hint from y (prefer last dim).
if y.ndim == y_pred.ndim:
plausible = {1, y.shape[1], y.shape[-1]}
if y_pred.shape[1] not in plausible:
hint = int(y.shape[-1]) if y.shape[-1] in plausible else int(y.shape[1])
y_pred, _ = ensure_channel_first(y_pred, channel_hint=hint)
# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]
# Normalize y if it plausibly is channel-last.
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)
# Inconsistent pair: label-map y_pred (C==1) vs one-hot y (last dim > 1) without num_classes.
if (
self.num_classes is None
and y.ndim == y_pred.ndim
and y_pred.shape[1] == 1
and y.shape[-1] > 1
and y.shape[1] == 1
):
raise ValueError(
"Inconsistent inputs: y_pred appears to be a label map (C==1) while y is one-hot "
"channel-last (last dim > 1) but `num_classes` is not provided. "
"Specify `num_classes` to compute per-class Dice."
)
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 314 to 333, the current heuristic for
normalizing channel layout can silently accept ambiguous NHWC shapes (e.g.
H==W==C) and misinterpret channels when num_classes is None; update the logic to
detect and fail-fast on ambiguous layouts when num_classes is not provided, only
perform the y-informed retry when the hint is plausible (e.g., y.shape[-1] in
{1, y_pred.shape[-1], expected range}) and otherwise raise a ValueError
describing the ambiguity; additionally, add an explicit check that rejects the
case where y_pred is a label map (channels==1) while y is clearly one-hot
(channels>1) unless num_classes is supplied, and use num_classes to disambiguate
or convert confidently when available.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant