Skip to content

Commit 08b1cfe

Browse files
timblakelycopybara-github
authored andcommitted
Add ReconcileAndFilterFlows processor.
PiperOrigin-RevId: 673524211
1 parent 73efc8c commit 08b1cfe

File tree

1 file changed

+251
-7
lines changed

1 file changed

+251
-7
lines changed

processor/flow.py

Lines changed: 251 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,25 @@
1515
"""Flow field estimation from SOFIMA."""
1616

1717
import dataclasses
18-
from typing import Optional
19-
18+
from typing import Any, Sequence
19+
from connectomics.common import beam_utils
2020
from connectomics.common import bounding_box
2121
from connectomics.common import utils
22-
from connectomics.volume import mask
22+
from connectomics.volume import base
23+
from connectomics.volume import mask as mask_lib
24+
from connectomics.volume import metadata
2325
from connectomics.volume import subvolume
2426
from connectomics.volume import subvolume_processor
2527
import dataclasses_json
2628
import numpy as np
29+
from scipy import interpolate
2730
from sofima import flow_field
31+
from sofima import flow_utils
32+
from sofima import map_utils
33+
34+
35+
Subvolume = subvolume.Subvolume
36+
SubvolumeOrMany = Subvolume | list[Subvolume]
2837

2938

3039
class EstimateFlow(subvolume_processor.SubvolumeProcessor):
@@ -77,9 +86,9 @@ class EstimateFlowConfig(utils.NPDataClassJsonMixin):
7786
stride: int
7887
z_stride: int = 1
7988
fixed_current: bool = False
80-
mask_configs: Optional[mask.MaskConfigs] = None
89+
mask_configs: mask_lib.MaskConfigs | None = None
8190
mask_only_for_patch_selection: bool = False
82-
selection_mask_configs: Optional[mask.MaskConfigs] = None
91+
selection_mask_configs: mask_lib.MaskConfigs | None = None
8392
batch_size: int = 1024
8493

8594
_config: EstimateFlowConfig
@@ -143,7 +152,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.Subvolume:
143152
if config.mask_configs is not None:
144153
# TODO(blakely): Remove the unused lambda here and below when the external
145154
# paths support DecoratorSpecs.
146-
initial_mask = mask.build_mask(
155+
initial_mask = mask_lib.build_mask(
147156
config.mask_configs, subvol.bbox, lambda x: x
148157
)
149158

@@ -160,7 +169,7 @@ def process(self, subvol: subvolume.Subvolume) -> subvolume.Subvolume:
160169
subvol.bbox.size - xy * config.patch_size + xy * config.stride
161170
) / scale
162171
sel_box = bounding_box.BoundingBox(sel_start, sel_size)
163-
sel_mask = mask.build_mask(
172+
sel_mask = mask_lib.build_mask(
164173
config.selection_mask_configs, sel_box, lambda x: x
165174
)
166175

@@ -259,3 +268,238 @@ def expected_output_box(
259268
+ self._config.stride
260269
) // self._config.stride
261270
return bounding_box.BoundingBox(start, size)
271+
272+
273+
# TODO(blakely): Remove references to volinfos in favor of metadata
274+
class ReconcileAndFilterFlows(subvolume_processor.SubvolumeProcessor):
275+
"""Filters 4-channel or 3-channel flow volumes.
276+
277+
The input flow volume(s) (generated by EstimateFlow) are filtered to
278+
only retain 'valid' entries fulfilling local consistency and estimation
279+
confidence criteria. If additional (lower-resolution) flow estimates
280+
are provided via 'flow_volinfos', they are used to fill any flow
281+
entries considered 'invalid' after filtering the higher resolution
282+
results.
283+
"""
284+
285+
crop_at_borders = False
286+
287+
@dataclasses_json.dataclass_json
288+
@dataclasses.dataclass(eq=True)
289+
class ReconcileFlowsConfig(utils.NPDataClassJsonMixin):
290+
"""Configuration for ReconcileAndFilterFlows.
291+
292+
Attributes:
293+
flow_volinfos: List or comma-separated string of volinfo paths, sorted in
294+
ascending order of voxel size; a path can optionally be followed by
295+
':scale', which defines a divisor to apply to the corresponding flow
296+
field. If the divisor is not specified, its value is inferred from the
297+
pixel size ratio between the given flow field and the first flow field
298+
on the list.
299+
mask_configs: MaskConfigs proto in text format; masked voxels will be set
300+
to nan (in both channels)
301+
min_peak_ratio: See flow_utils.clean_flow.
302+
min_peak_sharpness: See flow_utils.clean_flow.
303+
max_magnitude: See flow_utils.clean_flow.
304+
max_deviation: See flow_utils.clean_flow.
305+
max_gradient: See flow_utils.clean_flow.
306+
min_patch_size: See flow_utils.clean_flow.
307+
multi_section: If generating a multi-section volume, the value of the 3rd
308+
channel to initialize the output flow with
309+
base_delta_z: If generating a multi-section volume, the value of the 3rd
310+
channel to initialize the output flow with
311+
"""
312+
313+
flow_volinfos: Sequence[str] | str | None = None
314+
mask_configs: str | mask_lib.MaskConfigs | None = None
315+
min_peak_ratio: float = 1.6
316+
min_peak_sharpness: float = 1.6
317+
max_magnitude: float = 40
318+
max_deviation: float = 10
319+
max_gradient: float = 40
320+
min_patch_size: int = 400
321+
multi_section: bool = False
322+
base_delta_z: int = 0
323+
324+
_config: ReconcileFlowsConfig
325+
326+
def __init__(
327+
self,
328+
config: ReconcileFlowsConfig,
329+
input_volinfo_or_metadata: str | metadata.VolumeMetadata | None = None,
330+
):
331+
"""Constructor.
332+
333+
Args:
334+
config: Parameters for ReconcileAndFilterFlows
335+
input_volinfo_or_metadata: input volume with a voxel size equal or smaller
336+
than the first volume in the flow_volinfos list
337+
"""
338+
self._config = config
339+
340+
self._scales = [None]
341+
self._metadata: list[metadata.VolumeMetadata] = []
342+
if input_volinfo_or_metadata is not None:
343+
self._metadata.append(self._get_metadata(input_volinfo_or_metadata))
344+
if isinstance(config.flow_volinfos, str):
345+
config.flow_volinfos = config.flow_volinfos.split(',')
346+
347+
for path in config.flow_volinfos:
348+
path, _, scale = path.partition(':')
349+
if scale:
350+
scale = float(scale)
351+
else:
352+
scale = None
353+
354+
self._scales.append(scale)
355+
self._metadata.append(self._get_metadata(path))
356+
357+
# Ensure that the volumes are correctly sorted.
358+
for a, b in zip(self._metadata, self._metadata[1:]):
359+
assert a.pixel_size.x <= b.pixel_size.x
360+
assert a.pixel_size.y <= b.pixel_size.y
361+
assert a.pixel_size.x / b.pixel_size.x == a.pixel_size.y / b.pixel_size.y
362+
assert a.pixel_size.z == b.pixel_size.z
363+
364+
if config.mask_configs is not None:
365+
if isinstance(config.mask_configs, str):
366+
config.mask_configs = self._get_mask_configs(config.mask_configs)
367+
368+
def _open_volume(self, path: str) -> base.Volume:
369+
"""Returns a CZYX-shaped ndarray-like object."""
370+
raise NotImplementedError(
371+
'This function needs to be defined in a subclass.'
372+
)
373+
374+
def _get_metadata(self, path) -> metadata.VolumeMetadata:
375+
raise NotImplementedError(
376+
'This function needs to be defined in a subclass.'
377+
)
378+
379+
def _get_mask_configs(self, mask_configs: str) -> mask_lib.MaskConfigs:
380+
raise NotImplementedError(
381+
'This function needs to be defined in a subclass.'
382+
)
383+
384+
def _build_mask(
385+
self,
386+
mask_configs: mask_lib.MaskConfigs,
387+
box: bounding_box.BoundingBoxBase,
388+
) -> Any:
389+
"""Returns a CZYX-shaped ndarray-like object."""
390+
raise NotImplementedError(
391+
'This function needs to be defined in a subclass.'
392+
)
393+
394+
def num_channels(self, input_channels=0):
395+
del input_channels
396+
return 2 if not self._config.multi_section else 3
397+
398+
def process(self, subvol: Subvolume) -> SubvolumeOrMany:
399+
box = subvol.bbox
400+
if self._config.mask_configs is not None:
401+
mask = self._build_mask(self._config.mask_configs, box)
402+
else:
403+
mask = None
404+
405+
# Points in image space at which the base (highest resolution) flow
406+
# is defined. Pixel values are assumed to correspond to the middle
407+
# point of the pixel.
408+
qy, qx = np.mgrid[: box.size[1], : box.size[0]]
409+
qx = qx + box.start[0]
410+
qy = qy + box.start[1]
411+
412+
flows = []
413+
volumes = [self._open_volume(v) for v in self._metadata]
414+
415+
for i, (vol, mag_scale) in enumerate(zip(volumes, self._scales)):
416+
if i > 0:
417+
scale = self._metadata[0].pixel_size.x / self._metadata[i].pixel_size.x
418+
assert scale <= 1.0
419+
read_box = box.scale((scale, scale, 1))
420+
if scale < 1:
421+
read_box = read_box.adjusted_by(
422+
start=-self._context[0], end=self._context[1]
423+
)
424+
read_box = vol.clip_box_to_volume(read_box)
425+
assert read_box is not None
426+
else:
427+
scale = 1
428+
read_box = box
429+
430+
with beam_utils.timer_counter(
431+
'reconcile-flows', 'time-volstore-load-%d' % i
432+
):
433+
flow = vol[read_box.to_slice4d()]
434+
435+
with beam_utils.timer_counter('reconcile-flows', 'time-clean-%d' % i):
436+
flow = flow_utils.clean_flow(
437+
flow,
438+
self._config.min_peak_ratio,
439+
self._config.min_peak_sharpness,
440+
self._config.max_magnitude,
441+
self._config.max_deviation,
442+
)
443+
444+
if i == 0 or scale == 1:
445+
if self._config.multi_section and flow.shape[0] != 3:
446+
shape = np.array(flow.shape)
447+
shape[0] = 3
448+
nflow = np.full(shape, np.nan, dtype=flow.dtype)
449+
nflow[:2, ...] = flow[:2, ...]
450+
nflow[2, ...][np.isfinite(nflow[0, ...])] = self._config.base_delta_z
451+
flow = nflow
452+
453+
flows.append(flow)
454+
continue
455+
456+
# Upsample flow to the base resolution.
457+
hires_flow = np.zeros_like(flows[0])
458+
459+
oy, ox = np.ogrid[: read_box.size[1], : read_box.size[0]]
460+
ox = ox + read_box.start[0]
461+
oy = oy + read_box.start[1]
462+
ox = (ox / scale).ravel()
463+
oy = (oy / scale).ravel()
464+
465+
if mag_scale is None:
466+
mag_scale = scale
467+
468+
with beam_utils.timer_counter('reconcile-flows', 'time-upsample-%d' % i):
469+
for z in range(flow.shape[1]):
470+
rgi = interpolate.RegularGridInterpolator(
471+
(oy, ox), flow[0, z, ...], method='nearest', bounds_error=False
472+
)
473+
invalid_mask = np.isnan(rgi((qy, qx)))
474+
475+
# We want to upsample the spatial components of the flow with
476+
# at least linear interpolation. Doing so with RegularGridInterpolator
477+
# in the presence of invalid entries (NaN) will cause the invalid
478+
# regions to grow beyond what 'nearest' upsampling would generate.
479+
# To avoid this, we use a resampling scheme with interpolation and
480+
# mask out invalid entries as if the field was resampled in
481+
# the 'nearest' interpolation mode.
482+
resampled = map_utils.resample_map(
483+
flow[:2, z : z + 1, ...], read_box, box, 1 / scale, 1 #
484+
)
485+
hires_flow[:2, z : z + 1, ...] = resampled / mag_scale
486+
hires_flow[0, z, ...][invalid_mask] = np.nan
487+
hires_flow[1, z, ...][invalid_mask] = np.nan
488+
489+
for c in range(2, self.num_channels()):
490+
rgi = interpolate.RegularGridInterpolator(
491+
(oy, ox), flow[c, z, ...], method='nearest', bounds_error=False
492+
)
493+
hires_flow[c, z, ...] = rgi((qy, qx)).astype(np.float32)
494+
495+
if mask is not None:
496+
flow_utils.apply_mask(hires_flow, mask)
497+
flows.append(hires_flow)
498+
499+
ret = flow_utils.reconcile_flows(
500+
flows,
501+
self._config.max_gradient,
502+
self._config.max_deviation,
503+
self._config.min_patch_size,
504+
)
505+
return self.crop_box_and_data(box, ret)

0 commit comments

Comments
 (0)