Skip to content

Commit cafc1fe

Browse files
Pytorch 2.8 Support (#8530)
Fixes #8529. ### Description This adds support for PyTorch 2.8. This required a few minor changes to how determinism worked, and skipping some tests which mysteriously started failing under Windows which seems related to GLOO under 2.8. Mypy seems to have updated with new things it found and a number of spurious typing issues so many more files had to be fixed for that. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b92b2ce commit cafc1fe

29 files changed

+116
-95
lines changed

.github/workflows/pythonapp-min.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ jobs:
124124
strategy:
125125
fail-fast: false
126126
matrix:
127-
pytorch-version: ['2.4.1', '2.5.1', '2.6.0', '2.7.1']
127+
pytorch-version: ['2.5.1', '2.6.0', '2.7.1', '2.8.0']
128128
timeout-minutes: 40
129129
steps:
130130
- uses: actions/checkout@v4

.github/workflows/pythonapp.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
- if: runner.os == 'windows'
9595
name: Install torch cpu from pytorch.org (Windows only)
9696
run: |
97-
python -m pip install torch==2.4.1 torchvision==0.19.1+cpu --index-url https://download.pytorch.org/whl/cpu
97+
python -m pip install torch==2.5.1 torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu
9898
- if: runner.os == 'Linux'
9999
name: Install itk pre-release (Linux only)
100100
run: |
@@ -103,7 +103,7 @@ jobs:
103103
- name: Install the dependencies
104104
run: |
105105
python -m pip install --user --upgrade pip wheel
106-
python -m pip install torch==2.4.1 torchvision==0.19.1
106+
python -m pip install torch==2.5.1 torchvision==0.20.1
107107
cat "requirements-dev.txt"
108108
python -m pip install -r requirements-dev.txt
109109
python -m pip list
@@ -155,7 +155,7 @@ jobs:
155155
# install the latest pytorch for testing
156156
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
157157
# fresh torch installation according to pyproject.toml
158-
python -m pip install torch>=2.4.1 torchvision
158+
python -m pip install torch>=2.5.1 torchvision
159159
- name: Check packages
160160
run: |
161161
pip uninstall monai

monai/apps/detection/metrics/coco.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def _compute_statistics(self, results_list: list[dict[int, dict[str, np.ndarray]
457457
dt_ignores = np.concatenate([r["dtIgnore"][:, 0:max_det] for r in results], axis=1)[:, inds]
458458
self.check_number_of_iou(dt_matches, dt_ignores)
459459
gt_ignore = np.concatenate([r["gtIgnore"] for r in results])
460-
num_gt = np.count_nonzero(gt_ignore == 0) # number of ground truth boxes (non ignored)
460+
num_gt = int(np.count_nonzero(gt_ignore == 0)) # number of ground truth boxes (non ignored)
461461
if num_gt == 0:
462462
logger.warning(f"WARNING, no gt found for coco metric for class {cls_i}")
463463
continue
@@ -523,13 +523,12 @@ def _compute_stats_single_threshold(
523523
recall = 0
524524

525525
# array where precision values nearest to given recall th are saved
526-
precision = np.zeros((num_recall_th,))
526+
precision = [0.0] * num_recall_th
527527
# save scores for corresponding recall value in here
528528
th_scores = np.zeros((num_recall_th,))
529529
# numpy is slow without cython optimization for accessing elements
530530
# use python array gets significant speed improvement
531531
pr = pr.tolist()
532-
precision = precision.tolist()
533532

534533
# smooth precision curve (create box shape)
535534
for i in range(len(tp) - 1, 0, -1):

monai/apps/detection/utils/box_coder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:
210210
offset = reference_boxes.shape[-1]
211211

212212
pred_boxes = []
213-
boxes_cccwhd = convert_box_mode(reference_boxes, src_mode=StandardMode, dst_mode=CenterSizeMode)
213+
boxes_cccwhd: torch.Tensor = convert_box_mode(
214+
reference_boxes, src_mode=StandardMode, dst_mode=CenterSizeMode
215+
) # type: ignore[assignment]
216+
214217
for axis in range(self.spatial_dims):
215218
whd_axis = boxes_cccwhd[:, axis + self.spatial_dims]
216219
ctr_xyz_axis = boxes_cccwhd[:, axis]

monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,9 @@ def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals):
358358

359359
def _apply_up_blocks(self, h, emb, context, down_block_res_samples):
360360
for upsample_block in self.up_blocks:
361-
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
362-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
361+
idx: int = -len(upsample_block.resnets) # type: ignore
362+
res_samples = down_block_res_samples[idx:]
363+
down_block_res_samples = down_block_res_samples[:idx]
363364
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
364365

365366
return h

monai/data/box_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -811,9 +811,9 @@ def _box_inter_union(
811811

812812
# compute size for the intersection region for the NxM combinations
813813
wh = (rb - lt + TO_REMOVE).clamp(min=0) # (N,M,spatial_dims)
814-
inter = torch.prod(wh, dim=-1, keepdim=False) # (N,M)
814+
inter: torch.Tensor = torch.prod(wh, dim=-1, keepdim=False) # (N,M)
815815

816-
union = area1[:, None] + area2 - inter
816+
union: torch.Tensor = area1[:, None] + area2 - inter # type: ignore
817817
return inter, union
818818

819819

@@ -981,7 +981,7 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
981981
wh = (rb - lt + TO_REMOVE).clamp(min=0) # (N,spatial_dims)
982982
enclosure = torch.prod(wh, dim=-1, keepdim=False) # (N,)
983983

984-
giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)
984+
giou_t: torch.Tensor = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) # type: ignore
985985
giou_t = giou_t.to(dtype=box_dtype) # (N,spatial_dims)
986986
if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
987987
raise ValueError("Box GIoU is NaN or Inf.")

monai/data/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ def __len__(self) -> int:
13531353
return len(self.dataset)
13541354

13551355
def randomize(self, data: Any | None = None) -> None:
1356-
self._seed = self.R.randint(MAX_SEED, dtype="uint32")
1356+
self._seed = int(self.R.randint(MAX_SEED, dtype="uint32"))
13571357

13581358
def __getitem__(self, index: int):
13591359
self.randomize()

monai/data/image_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __len__(self) -> int:
9797
return len(self.image_files)
9898

9999
def randomize(self, data: Any | None = None) -> None:
100-
self._seed = self.R.randint(MAX_SEED, dtype="uint32")
100+
self._seed = int(self.R.randint(MAX_SEED, dtype="uint32"))
101101

102102
def __getitem__(self, index: int):
103103
self.randomize()

monai/data/image_reader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
580580
shape = first_array.shape
581581
spacing = getattr(first_slice, "PixelSpacing", [1.0] * len(shape))
582582
prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2]
583-
stack_array = [first_array]
583+
stack_array_list: list = [first_array]
584584
for idx in range(1, len(slices)):
585585
slc_array = self._get_array_data(slices[idx][0], slices[idx][1])
586586
slc_shape = slc_array.shape
@@ -592,22 +592,24 @@ def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
592592
warnings.warn(f"the list contains slices that have different shapes {shape} and {slc_shape}.")
593593
average_distance += abs(prev_pos - slc_pos)
594594
prev_pos = slc_pos
595-
stack_array.append(slc_array)
595+
stack_array_list.append(slc_array)
596596

597597
if len(slices) > 1:
598598
average_distance /= len(slices) - 1
599599
spacing.append(average_distance)
600600
if self.to_gpu:
601-
stack_array = cp.stack(stack_array, axis=-1)
601+
stack_array = cp.stack(stack_array_list, axis=-1)
602602
else:
603-
stack_array = np.stack(stack_array, axis=-1)
603+
stack_array = np.stack(stack_array_list, axis=-1)
604+
605+
del stack_array_list[:]
604606
stack_metadata = self._get_meta_dict(first_slice)
605607
stack_metadata["spacing"] = np.asarray(spacing)
606608
if hasattr(slices[-1][0], "ImagePositionPatient"):
607609
stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1][0].ImagePositionPatient)
608610
stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),)
609611
else:
610-
stack_array = stack_array[0]
612+
stack_array = stack_array_list[0]
611613
stack_metadata = self._get_meta_dict(first_slice)
612614
stack_metadata["spacing"] = np.asarray(spacing)
613615
stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape

monai/networks/nets/diffusion_model_unet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,8 +1795,9 @@ def forward(
17951795

17961796
# 6. up
17971797
for upsample_block in self.up_blocks:
1798-
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1799-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1798+
idx: int = -len(upsample_block.resnets) # type: ignore
1799+
res_samples = down_block_res_samples[idx:]
1800+
down_block_res_samples = down_block_res_samples[:idx]
18001801
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
18011802

18021803
# 7. output block

0 commit comments

Comments
 (0)