53
53
"""
54
54
55
55
import collections
56
- from typing import Optional , Sequence , Union
56
+ from typing import Sequence
57
57
from connectomics .common import bounding_box
58
58
import jax
59
59
import jax .numpy as jnp
63
63
from scipy import spatial
64
64
65
65
66
- StrideZYX = Union [ float , Sequence [float ] ]
67
- ShapeZYX = Union [ tuple [int , int ], tuple [int , int , int ] ]
66
+ StrideZYX = float | Sequence [float ]
67
+ ShapeZYX = tuple [int , int ] | tuple [int , int , int ]
68
68
69
69
70
70
def _interpolate_points (
@@ -150,7 +150,7 @@ def _identity_map_absolute(
150
150
def to_absolute (
151
151
coord_map : np .ndarray ,
152
152
stride : StrideZYX ,
153
- box : Optional [ bounding_box .BoundingBox ] = None ,
153
+ box : bounding_box .BoundingBox | None = None ,
154
154
) -> np .ndarray :
155
155
"""Converts a coordinate map from relative to absolute representation.
156
156
@@ -190,7 +190,7 @@ def to_absolute(
190
190
def to_relative (
191
191
coord_map : np .ndarray ,
192
192
stride : StrideZYX ,
193
- box : Optional [ bounding_box .BoundingBox ] = None ,
193
+ box : bounding_box .BoundingBox | None = None ,
194
194
) -> np .ndarray :
195
195
"""Converts a coordinate map from absolute to relative representation.
196
196
@@ -308,7 +308,7 @@ def outer_box(
308
308
coord_map : np .ndarray ,
309
309
box : bounding_box .BoundingBox ,
310
310
stride : StrideZYX ,
311
- target_len : Optional [ StrideZYX ] = None ,
311
+ target_len : StrideZYX | None = None ,
312
312
) -> bounding_box .BoundingBox :
313
313
"""Returns a bounding box covering all target nodes.
314
314
@@ -343,7 +343,7 @@ def outer_box(
343
343
344
344
345
345
def inner_box (
346
- coord_map : np .ndarray , box : bounding_box .BoundingBox , stride : float
346
+ coord_map : np .ndarray , box : bounding_box .BoundingBox , stride : StrideZYX
347
347
) -> bounding_box .BoundingBox :
348
348
"""Returns a box within which all nodes are mapped to by coord map.
349
349
@@ -356,7 +356,9 @@ def inner_box(
356
356
bounding box, all (u, v[, w]) points contained within which have
357
357
an entry in the (x, y[, z]) -> (u, v[, w]) map
358
358
"""
359
- assert coord_map .shape [0 ] in (2 , 3 )
359
+ dim = coord_map .shape [0 ]
360
+ assert dim in (2 , 3 )
361
+ stride = _as_vec (stride , dim )
360
362
361
363
# Part of the map might be invalid, in which case we extrapolate
362
364
# in order to get a fully valid array.
@@ -366,21 +368,21 @@ def inner_box(
366
368
y0 = np .max (np .min (int_map [1 , ...], axis = - 2 ))
367
369
y1 = np .min (np .max (int_map [1 , ...], axis = - 2 ))
368
370
369
- x0 = int (- (- x0 // stride ))
370
- y0 = int (- (- y0 // stride ))
371
- x1 = x1 // stride
372
- y1 = y1 // stride
371
+ x0 = int (- (- x0 // stride [ - 1 ] ))
372
+ y0 = int (- (- y0 // stride [ - 2 ] ))
373
+ x1 = x1 // stride [ - 1 ]
374
+ y1 = y1 // stride [ - 2 ]
373
375
374
- if coord_map . shape [ 0 ] == 2 :
376
+ if dim == 2 :
375
377
return bounding_box .BoundingBox (
376
378
start = (x0 , y0 , box .start [2 ]),
377
379
size = (x1 - x0 + 1 , y1 - y0 + 1 , box .size [2 ]),
378
380
)
379
381
380
382
z0 = np .max (np .min (int_map [2 , ...], axis = - 3 ))
381
383
z1 = np .min (np .max (int_map [2 , ...], axis = - 3 ))
382
- z0 = int (- (- z0 // stride ))
383
- z1 = z1 // stride
384
+ z0 = int (- (- z0 // stride [ 0 ] ))
385
+ z1 = z1 // stride [ 0 ]
384
386
385
387
return bounding_box .BoundingBox (
386
388
start = (x0 , y0 , z0 ), size = (x1 - x0 + 1 , y1 - y0 + 1 , z1 - z0 + 1 )
@@ -736,7 +738,7 @@ def mask_irregular(
736
738
coord_map : np .ndarray ,
737
739
stride : float ,
738
740
frac : float ,
739
- max_frac : Optional [ float ] = None ,
741
+ max_frac : float | None = None ,
740
742
dilation_iters : int = 1 ,
741
743
) -> np .ndarray :
742
744
"""Masks stretched/folded parts of the map.
0 commit comments