Skip to content

Commit 4b3a011

Browse files
mjanuszcopybara-github
authored andcommitted
Add support for 3d stride in more functions.
PiperOrigin-RevId: 674922489
1 parent 62733d0 commit 4b3a011

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

map_utils.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
"""
5454

5555
import collections
56-
from typing import Optional, Sequence, Union
56+
from typing import Sequence
5757
from connectomics.common import bounding_box
5858
import jax
5959
import jax.numpy as jnp
@@ -63,8 +63,8 @@
6363
from scipy import spatial
6464

6565

66-
StrideZYX = Union[float, Sequence[float]]
67-
ShapeZYX = Union[tuple[int, int], tuple[int, int, int]]
66+
StrideZYX = float | Sequence[float]
67+
ShapeZYX = tuple[int, int] | tuple[int, int, int]
6868

6969

7070
def _interpolate_points(
@@ -150,7 +150,7 @@ def _identity_map_absolute(
150150
def to_absolute(
151151
coord_map: np.ndarray,
152152
stride: StrideZYX,
153-
box: Optional[bounding_box.BoundingBox] = None,
153+
box: bounding_box.BoundingBox | None = None,
154154
) -> np.ndarray:
155155
"""Converts a coordinate map from relative to absolute representation.
156156
@@ -190,7 +190,7 @@ def to_absolute(
190190
def to_relative(
191191
coord_map: np.ndarray,
192192
stride: StrideZYX,
193-
box: Optional[bounding_box.BoundingBox] = None,
193+
box: bounding_box.BoundingBox | None = None,
194194
) -> np.ndarray:
195195
"""Converts a coordinate map from absolute to relative representation.
196196
@@ -308,7 +308,7 @@ def outer_box(
308308
coord_map: np.ndarray,
309309
box: bounding_box.BoundingBox,
310310
stride: StrideZYX,
311-
target_len: Optional[StrideZYX] = None,
311+
target_len: StrideZYX | None = None,
312312
) -> bounding_box.BoundingBox:
313313
"""Returns a bounding box covering all target nodes.
314314
@@ -343,7 +343,7 @@ def outer_box(
343343

344344

345345
def inner_box(
346-
coord_map: np.ndarray, box: bounding_box.BoundingBox, stride: float
346+
coord_map: np.ndarray, box: bounding_box.BoundingBox, stride: StrideZYX
347347
) -> bounding_box.BoundingBox:
348348
"""Returns a box within which all nodes are mapped to by coord map.
349349
@@ -356,7 +356,9 @@ def inner_box(
356356
bounding box, all (u, v[, w]) points contained within which have
357357
an entry in the (x, y[, z]) -> (u, v[, w]) map
358358
"""
359-
assert coord_map.shape[0] in (2, 3)
359+
dim = coord_map.shape[0]
360+
assert dim in (2, 3)
361+
stride = _as_vec(stride, dim)
360362

361363
# Part of the map might be invalid, in which case we extrapolate
362364
# in order to get a fully valid array.
@@ -366,21 +368,21 @@ def inner_box(
366368
y0 = np.max(np.min(int_map[1, ...], axis=-2))
367369
y1 = np.min(np.max(int_map[1, ...], axis=-2))
368370

369-
x0 = int(-(-x0 // stride))
370-
y0 = int(-(-y0 // stride))
371-
x1 = x1 // stride
372-
y1 = y1 // stride
371+
x0 = int(-(-x0 // stride[-1]))
372+
y0 = int(-(-y0 // stride[-2]))
373+
x1 = x1 // stride[-1]
374+
y1 = y1 // stride[-2]
373375

374-
if coord_map.shape[0] == 2:
376+
if dim == 2:
375377
return bounding_box.BoundingBox(
376378
start=(x0, y0, box.start[2]),
377379
size=(x1 - x0 + 1, y1 - y0 + 1, box.size[2]),
378380
)
379381

380382
z0 = np.max(np.min(int_map[2, ...], axis=-3))
381383
z1 = np.min(np.max(int_map[2, ...], axis=-3))
382-
z0 = int(-(-z0 // stride))
383-
z1 = z1 // stride
384+
z0 = int(-(-z0 // stride[0]))
385+
z1 = z1 // stride[0]
384386

385387
return bounding_box.BoundingBox(
386388
start=(x0, y0, z0), size=(x1 - x0 + 1, y1 - y0 + 1, z1 - z0 + 1)
@@ -736,7 +738,7 @@ def mask_irregular(
736738
coord_map: np.ndarray,
737739
stride: float,
738740
frac: float,
739-
max_frac: Optional[float] = None,
741+
max_frac: float | None = None,
740742
dilation_iters: int = 1,
741743
) -> np.ndarray:
742744
"""Masks stretched/folded parts of the map.

processor/maps.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,13 @@ class InvertMap(subvolume_processor.SubvolumeProcessor):
323323
crop_at_borders = False
324324
output_num = subvolume_processor.OutputNums.MULTI
325325

326-
def __init__(self, stride: float, crop_output=True, input_volinfo=None):
326+
def __init__(
327+
self, stride: map_utils.StrideZYX, crop_output=True, input_volinfo=None
328+
):
327329
"""Constructor.
328330
329331
Args:
330-
stride: XY stride of the coordinate map
332+
stride: [Z]YX stride of the coordinate map
331333
crop_output: if False, outputs data for the input box instead of the inner
332334
box of the map; a typical use case is when inverting data for a complete
333335
section in which case there are no other work items that could provide

0 commit comments

Comments
 (0)