-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Fix: add ensure_channel_first in utils and integrate into DiceHelper (refs #8366) #8532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
f063cc2
dc721d4
99f7e30
95c58a0
ca32652
7bb9926
488e104
2f9254f
69b96e2
1080987
a360570
51eeb65
6d52c3c
12a34b7
adcbfdb
0fea070
c52922c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,3 +165,6 @@ runs | |
*.pth | ||
|
||
*zarr/* | ||
issue38366/ | ||
|
||
issue8366/ |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,6 +15,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from monai.metrics.utils import do_metric_reduction | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from monai.utils import MetricReduction, deprecated_arg | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from monai.inferers.utils import ensure_channel_first | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .metric import CumulativeIterationMetric | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -123,6 +125,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_classes=self.num_classes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Compute the dice value using ``DiceHelper``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -306,6 +309,28 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# --- Normalize layout to channel-first (N, C, spatial...) --- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Prefer a strong signal when available. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.num_classes is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# First pass: heuristic only. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_pred, _ = ensure_channel_first(y_pred) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Fallback: if implausible vs y's layout, retry with a hint from y's last dim. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if y.ndim == y_pred.ndim: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
plausible = {1, y.shape[1], y.shape[-1]} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if y_pred.shape[1] not in plausible: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Infer channels after normalization (or use provided). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n_ch = self.num_classes or y_pred.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Normalize y if it plausibly is channel-last. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y, _ = ensure_channel_first(y, channel_hint=n_ch) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+314
to
+333
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Fix NHWC/NCHW ambiguity; fail fast when num_classes is missing; strengthen fallback As written, ambiguous shapes can slip through (e.g., NHWC where H==W==C>32); the heuristic may preserve NHWC and silently misinterpret channels. Add an explicit ambiguity check when num_classes is None, and only retry with a y-informed hint when plausible. Also error on the classic “y_pred is label map (C==1) but y is one-hot CL with C>1” unless num_classes is provided. # --- Normalize layout to channel-first (N, C, spatial...) ---
# Prefer a strong signal when available.
- if self.num_classes is not None:
- y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
- else:
- # First pass: heuristic only.
- y_pred, _ = ensure_channel_first(y_pred)
- # Fallback: if implausible vs y's layout, retry with a hint from y's last dim.
- if y.ndim == y_pred.ndim:
- plausible = {1, y.shape[1], y.shape[-1]}
- if y_pred.shape[1] not in plausible:
- y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1]))
+ if self.num_classes is not None:
+ y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
+ else:
+ # Detect clearly ambiguous layouts (dim1 == dim-1 > 1) and require num_classes.
+ if y.ndim == y_pred.ndim and y.shape[1] == y.shape[-1] and y.shape[1] > 1:
+ raise ValueError(
+ "Ambiguous channel layout: y has identical sizes at dim=1 and dim=-1; "
+ "please provide `num_classes` to disambiguate."
+ )
+ # First pass: heuristic only.
+ y_pred, _ = ensure_channel_first(y_pred)
+ # Fallback: if implausible vs y's layout, retry with a hint from y (prefer last dim).
+ if y.ndim == y_pred.ndim:
+ plausible = {1, y.shape[1], y.shape[-1]}
+ if y_pred.shape[1] not in plausible:
+ hint = int(y.shape[-1]) if y.shape[-1] in plausible else int(y.shape[1])
+ y_pred, _ = ensure_channel_first(y_pred, channel_hint=hint)
# Infer channels after normalization (or use provided).
n_ch = self.num_classes or y_pred.shape[1]
# Normalize y if it plausibly is channel-last.
- if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
+ if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y, channel_hint=n_ch)
+ # Inconsistent pair: label-map y_pred (C==1) vs one-hot y (last dim > 1) without num_classes.
+ if (
+ self.num_classes is None
+ and y.ndim == y_pred.ndim
+ and y_pred.shape[1] == 1
+ and y.shape[-1] > 1
+ and y.shape[1] == 1
+ ):
+ raise ValueError(
+ "Inconsistent inputs: y_pred appears to be a label map (C==1) while y is one-hot "
+ "channel-last (last dim > 1) but `num_classes` is not provided. "
+ "Specify `num_classes` to compute per-class Dice."
+ ) Follow-ups:
I can draft the unit tests if you want them in tests/metrics/test_dice_layouts.py. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_apply_argmax, _threshold = self.apply_argmax, self.threshold | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.num_classes is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Validate import location and layering; avoid metrics → inferers dependency if possible
Importing ensure_channel_first from monai.inferers.utils couples metrics to inferers. Confirm final placement (inferers.utils vs metrics.utils as mentioned in the PR) and ensure no circulars. If this is a generic tensor-layout utility, consider relocating to a neutral utils module to reduce layering risk.
Run to verify single source and usages:
🏁 Script executed:
Length of output: 443
🏁 Script executed:
Length of output: 195
🏁 Script executed:
Length of output: 318
Relocate ensure_channel_first to a neutral utils module
Avoid coupling metrics to inferers by keeping generic tensor-layout helpers out of
monai.inferers
. Sinceensure_channel_first
is a generic utility, move it to an existing or new module undermonai/utils
(e.g.monai/utils/misc.py
ormonai/utils/tensor_utils.py
), then update imports accordingly. This reduces layering risk and prevents potential circular dependencies.• File: monai/metrics/meandice.py (line 18)
• New location: monai/utils/(misc.py or tensor_utils.py)
Example diff:
• After moving, remove the old definition in
monai/inferers/utils.py
or re-export it there if needed by inferers.• Verify no import cycles between
metrics
andinferers
after the change.🤖 Prompt for AI Agents