15
15
"""Flow field estimation from SOFIMA."""
16
16
17
17
import dataclasses
18
- from typing import Optional
19
-
18
+ from typing import Any , Sequence
19
+ from connectomics . common import beam_utils
20
20
from connectomics .common import bounding_box
21
21
from connectomics .common import utils
22
- from connectomics .volume import mask
22
+ from connectomics .volume import base
23
+ from connectomics .volume import mask as mask_lib
24
+ from connectomics .volume import metadata
23
25
from connectomics .volume import subvolume
24
26
from connectomics .volume import subvolume_processor
25
27
import dataclasses_json
26
28
import numpy as np
29
+ from scipy import interpolate
27
30
from sofima import flow_field
31
+ from sofima import flow_utils
32
+ from sofima import map_utils
33
+
34
+
35
+ Subvolume = subvolume .Subvolume
36
+ SubvolumeOrMany = Subvolume | list [Subvolume ]
28
37
29
38
30
39
class EstimateFlow (subvolume_processor .SubvolumeProcessor ):
@@ -77,9 +86,9 @@ class EstimateFlowConfig(utils.NPDataClassJsonMixin):
77
86
stride : int
78
87
z_stride : int = 1
79
88
fixed_current : bool = False
80
- mask_configs : Optional [ mask .MaskConfigs ] = None
89
+ mask_configs : mask_lib .MaskConfigs | None = None
81
90
mask_only_for_patch_selection : bool = False
82
- selection_mask_configs : Optional [ mask .MaskConfigs ] = None
91
+ selection_mask_configs : mask_lib .MaskConfigs | None = None
83
92
batch_size : int = 1024
84
93
85
94
_config : EstimateFlowConfig
@@ -143,7 +152,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.Subvolume:
143
152
if config .mask_configs is not None :
144
153
# TODO(blakely): Remove the unused lambda here and below when the external
145
154
# paths support DecoratorSpecs.
146
- initial_mask = mask .build_mask (
155
+ initial_mask = mask_lib .build_mask (
147
156
config .mask_configs , subvol .bbox , lambda x : x
148
157
)
149
158
@@ -160,7 +169,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.Subvolume:
160
169
subvol .bbox .size - xy * config .patch_size + xy * config .stride
161
170
) / scale
162
171
sel_box = bounding_box .BoundingBox (sel_start , sel_size )
163
- sel_mask = mask .build_mask (
172
+ sel_mask = mask_lib .build_mask (
164
173
config .selection_mask_configs , sel_box , lambda x : x
165
174
)
166
175
@@ -259,3 +268,238 @@ def expected_output_box(
259
268
+ self ._config .stride
260
269
) // self ._config .stride
261
270
return bounding_box .BoundingBox (start , size )
271
+
272
+
273
+ # TODO(blakely): Remove references to volinfos in favor of metadata
274
+ class ReconcileAndFilterFlows (subvolume_processor .SubvolumeProcessor ):
275
+ """Filters 4-channel or 3-channel flow volumes.
276
+
277
+ The input flow volume(s) (generated by EstimateFlow) are filtered to
278
+ only retain 'valid' entries fulfilling local consistency and estimation
279
+ confidence criteria. If additional (lower-resolution) flow estimates
280
+ are provided via 'flow_volinfos', they are used to fill any flow
281
+ entries considered 'invalid' after filtering the higher resolution
282
+ results.
283
+ """
284
+
285
+ crop_at_borders = False
286
+
287
+ @dataclasses_json .dataclass_json
288
+ @dataclasses .dataclass (eq = True )
289
+ class ReconcileFlowsConfig (utils .NPDataClassJsonMixin ):
290
+ """Configuration for ReconcileAndFilterFlows.
291
+
292
+ Attributes:
293
+ flow_volinfos: List or comma-separated string of volinfo paths, sorted in
294
+ ascending order of voxel size; a path can optionally be followed by
295
+ ':scale', which defines a divisor to apply to the corresponding flow
296
+ field. If the divisor is not specified, its value is inferred from the
297
+ pixel size ratio between the given flow field and the first flow field
298
+ on the list.
299
+ mask_configs: MaskConfigs proto in text format; masked voxels will be set
300
+ to nan (in both channels)
301
+ min_peak_ratio: See flow_utils.clean_flow.
302
+ min_peak_sharpness: See flow_utils.clean_flow.
303
+ max_magnitude: See flow_utils.clean_flow.
304
+ max_deviation: See flow_utils.clean_flow.
305
+ max_gradient: See flow_utils.clean_flow.
306
+ min_patch_size: See flow_utils.clean_flow.
307
+ multi_section: If generating a multi-section volume, the value of the 3rd
308
+ channel to initialize the output flow with
309
+ base_delta_z: If generating a multi-section volume, the value of the 3rd
310
+ channel to initialize the output flow with
311
+ """
312
+
313
+ flow_volinfos : Sequence [str ] | str | None = None
314
+ mask_configs : str | mask_lib .MaskConfigs | None = None
315
+ min_peak_ratio : float = 1.6
316
+ min_peak_sharpness : float = 1.6
317
+ max_magnitude : float = 40
318
+ max_deviation : float = 10
319
+ max_gradient : float = 40
320
+ min_patch_size : int = 400
321
+ multi_section : bool = False
322
+ base_delta_z : int = 0
323
+
324
+ _config : ReconcileFlowsConfig
325
+
326
+ def __init__ (
327
+ self ,
328
+ config : ReconcileFlowsConfig ,
329
+ input_volinfo_or_metadata : str | metadata .VolumeMetadata | None = None ,
330
+ ):
331
+ """Constructor.
332
+
333
+ Args:
334
+ config: Parameters for ReconcileAndFilterFlows
335
+ input_volinfo_or_metadata: input volume with a voxel size equal or smaller
336
+ than the first volume in the flow_volinfos list
337
+ """
338
+ self ._config = config
339
+
340
+ self ._scales = [None ]
341
+ self ._metadata : list [metadata .VolumeMetadata ] = []
342
+ if input_volinfo_or_metadata is not None :
343
+ self ._metadata .append (self ._get_metadata (input_volinfo_or_metadata ))
344
+ if isinstance (config .flow_volinfos , str ):
345
+ config .flow_volinfos = config .flow_volinfos .split (',' )
346
+
347
+ for path in config .flow_volinfos :
348
+ path , _ , scale = path .partition (':' )
349
+ if scale :
350
+ scale = float (scale )
351
+ else :
352
+ scale = None
353
+
354
+ self ._scales .append (scale )
355
+ self ._metadata .append (self ._get_metadata (path ))
356
+
357
+ # Ensure that the volumes are correctly sorted.
358
+ for a , b in zip (self ._metadata , self ._metadata [1 :]):
359
+ assert a .pixel_size .x <= b .pixel_size .x
360
+ assert a .pixel_size .y <= b .pixel_size .y
361
+ assert a .pixel_size .x / b .pixel_size .x == a .pixel_size .y / b .pixel_size .y
362
+ assert a .pixel_size .z == b .pixel_size .z
363
+
364
+ if config .mask_configs is not None :
365
+ if isinstance (config .mask_configs , str ):
366
+ config .mask_configs = self ._get_mask_configs (config .mask_configs )
367
+
368
+ def _open_volume (self , path : str ) -> base .Volume :
369
+ """Returns a CZYX-shaped ndarray-like object."""
370
+ raise NotImplementedError (
371
+ 'This function needs to be defined in a subclass.'
372
+ )
373
+
374
+ def _get_metadata (self , path ) -> metadata .VolumeMetadata :
375
+ raise NotImplementedError (
376
+ 'This function needs to be defined in a subclass.'
377
+ )
378
+
379
+ def _get_mask_configs (self , mask_configs : str ) -> mask_lib .MaskConfigs :
380
+ raise NotImplementedError (
381
+ 'This function needs to be defined in a subclass.'
382
+ )
383
+
384
+ def _build_mask (
385
+ self ,
386
+ mask_configs : mask_lib .MaskConfigs ,
387
+ box : bounding_box .BoundingBoxBase ,
388
+ ) -> Any :
389
+ """Returns a CZYX-shaped ndarray-like object."""
390
+ raise NotImplementedError (
391
+ 'This function needs to be defined in a subclass.'
392
+ )
393
+
394
+ def num_channels (self , input_channels = 0 ):
395
+ del input_channels
396
+ return 2 if not self ._config .multi_section else 3
397
+
398
+ def process (self , subvol : Subvolume ) -> SubvolumeOrMany :
399
+ box = subvol .bbox
400
+ if self ._config .mask_configs is not None :
401
+ mask = self ._build_mask (self ._config .mask_configs , box )
402
+ else :
403
+ mask = None
404
+
405
+ # Points in image space at which the base (highest resolution) flow
406
+ # is defined. Pixel values are assumed to correspond to the middle
407
+ # point of the pixel.
408
+ qy , qx = np .mgrid [: box .size [1 ], : box .size [0 ]]
409
+ qx = qx + box .start [0 ]
410
+ qy = qy + box .start [1 ]
411
+
412
+ flows = []
413
+ volumes = [self ._open_volume (v ) for v in self ._metadata ]
414
+
415
+ for i , (vol , mag_scale ) in enumerate (zip (volumes , self ._scales )):
416
+ if i > 0 :
417
+ scale = self ._metadata [0 ].pixel_size .x / self ._metadata [i ].pixel_size .x
418
+ assert scale <= 1.0
419
+ read_box = box .scale ((scale , scale , 1 ))
420
+ if scale < 1 :
421
+ read_box = read_box .adjusted_by (
422
+ start = - self ._context [0 ], end = self ._context [1 ]
423
+ )
424
+ read_box = vol .clip_box_to_volume (read_box )
425
+ assert read_box is not None
426
+ else :
427
+ scale = 1
428
+ read_box = box
429
+
430
+ with beam_utils .timer_counter (
431
+ 'reconcile-flows' , 'time-volstore-load-%d' % i
432
+ ):
433
+ flow = vol [read_box .to_slice4d ()]
434
+
435
+ with beam_utils .timer_counter ('reconcile-flows' , 'time-clean-%d' % i ):
436
+ flow = flow_utils .clean_flow (
437
+ flow ,
438
+ self ._config .min_peak_ratio ,
439
+ self ._config .min_peak_sharpness ,
440
+ self ._config .max_magnitude ,
441
+ self ._config .max_deviation ,
442
+ )
443
+
444
+ if i == 0 or scale == 1 :
445
+ if self ._config .multi_section and flow .shape [0 ] != 3 :
446
+ shape = np .array (flow .shape )
447
+ shape [0 ] = 3
448
+ nflow = np .full (shape , np .nan , dtype = flow .dtype )
449
+ nflow [:2 , ...] = flow [:2 , ...]
450
+ nflow [2 , ...][np .isfinite (nflow [0 , ...])] = self ._config .base_delta_z
451
+ flow = nflow
452
+
453
+ flows .append (flow )
454
+ continue
455
+
456
+ # Upsample flow to the base resolution.
457
+ hires_flow = np .zeros_like (flows [0 ])
458
+
459
+ oy , ox = np .ogrid [: read_box .size [1 ], : read_box .size [0 ]]
460
+ ox = ox + read_box .start [0 ]
461
+ oy = oy + read_box .start [1 ]
462
+ ox = (ox / scale ).ravel ()
463
+ oy = (oy / scale ).ravel ()
464
+
465
+ if mag_scale is None :
466
+ mag_scale = scale
467
+
468
+ with beam_utils .timer_counter ('reconcile-flows' , 'time-upsample-%d' % i ):
469
+ for z in range (flow .shape [1 ]):
470
+ rgi = interpolate .RegularGridInterpolator (
471
+ (oy , ox ), flow [0 , z , ...], method = 'nearest' , bounds_error = False
472
+ )
473
+ invalid_mask = np .isnan (rgi ((qy , qx )))
474
+
475
+ # We want to upsample the spatial components of the flow with
476
+ # at least linear interpolation. Doing so with RegularGridInterpolator
477
+ # in the presence of invalid entries (NaN) will cause the invalid
478
+ # regions to grow beyond what 'nearest' upsampling would generate.
479
+ # To avoid this, we use a resampling scheme with interpolation and
480
+ # mask out invalid entries as if the field was resampled in
481
+ # the 'nearest' interpolation mode.
482
+ resampled = map_utils .resample_map (
483
+ flow [:2 , z : z + 1 , ...], read_box , box , 1 / scale , 1 #
484
+ )
485
+ hires_flow [:2 , z : z + 1 , ...] = resampled / mag_scale
486
+ hires_flow [0 , z , ...][invalid_mask ] = np .nan
487
+ hires_flow [1 , z , ...][invalid_mask ] = np .nan
488
+
489
+ for c in range (2 , self .num_channels ()):
490
+ rgi = interpolate .RegularGridInterpolator (
491
+ (oy , ox ), flow [c , z , ...], method = 'nearest' , bounds_error = False
492
+ )
493
+ hires_flow [c , z , ...] = rgi ((qy , qx )).astype (np .float32 )
494
+
495
+ if mask is not None :
496
+ flow_utils .apply_mask (hires_flow , mask )
497
+ flows .append (hires_flow )
498
+
499
+ ret = flow_utils .reconcile_flows (
500
+ flows ,
501
+ self ._config .max_gradient ,
502
+ self ._config .max_deviation ,
503
+ self ._config .min_patch_size ,
504
+ )
505
+ return self .crop_box_and_data (box , ret )
0 commit comments