Skip to content

Commit e7accee

Browse files
timblakelycopybara-github
authored andcommitted
Add EstimateMissingFlow stage.
PiperOrigin-RevId: 667656708
1 parent 9d296e3 commit e7accee

File tree

1 file changed

+326
-1
lines changed

1 file changed

+326
-1
lines changed

processor/flow.py

Lines changed: 326 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
"""Flow field estimation from SOFIMA."""
1616

1717
import dataclasses
18+
import time
1819
from typing import Any, Sequence
20+
21+
from absl import logging
1922
from connectomics.common import beam_utils
2023
from connectomics.common import bounding_box
24+
from connectomics.common import file
2125
from connectomics.common import utils
2226
from connectomics.volume import base
2327
from connectomics.volume import mask as mask_lib
@@ -370,7 +374,7 @@ def __init__(
370374
if isinstance(config.mask_configs, str):
371375
config.mask_configs = self._get_mask_configs(config.mask_configs)
372376

373-
def _open_volume(self, path: str) -> base.Volume:
377+
def _open_volume(self, path: file.PathLike) -> base.Volume:
374378
"""Returns a CZYX-shaped ndarray-like object."""
375379
raise NotImplementedError(
376380
'This function needs to be defined in a subclass.'
@@ -508,3 +512,324 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
508512
self._config.min_patch_size,
509513
)
510514
return self.crop_box_and_data(box, ret)
515+
516+
517+
class EstimateMissingFlow(subvolume_processor.SubvolumeProcessor):
518+
"""Estimates a multi-section flow field.
519+
520+
Takes an existing single-section (2-channel) flow volume as input,
521+
and tries to compute flow vectors which are invalid in the input (NaNs).
522+
"""
523+
524+
@dataclasses_json.dataclass_json
525+
@dataclasses.dataclass(frozen=True)
526+
class EstimateMissingFlowConfig:
527+
"""Configuration for EstimateMissingFlow.
528+
529+
Attributes:
530+
patch_size: Patch size in pixels, divisible by 'stride'
531+
stride: XY stride size in pixels
532+
delta_z: Z stride size in pixels (Δz) for the input volume
533+
max_delta_z: Maximum Z stride with which try to estimate missing flow
534+
vectors
535+
max_attempts: Maximum number of attempts to estimate a flow vector for an
536+
unmasked location
537+
mask_configs: MaskConfigs proto in text format specifying a mask to
538+
exclude some voxels from the flow calculation
539+
mask_only_for_patch_selection: Whether to only use mask to decide for
540+
which patch pairs to compute flow
541+
selection_mask_configs: MaskConfigs in text format specifying a mask the
542+
positive entries of which indicate locations for which flow should be
543+
computed; this mask should have the same resolution and geometry as the
544+
output flow volume
545+
min_peak_ratio: Quality threshold for acceptance of newly estimated flow
546+
vectors; see flow_utils.clean_flow
547+
min_peak_sharpness: Quality threshold for acceptance of newly estimated
548+
flow vectors; see flow_utils.clean_flow
549+
max_magnitude: Maximum magnitude of a flow vector; see
550+
flow_utils.clean_flow
551+
batch_size: Max number of patches to process in parallel
552+
image_volinfo: Path to the VolumeInfo descriptor of the image volume
553+
image_cache_bytes: Number of bytes to use for the in-memory image cache;
554+
this should ideally be large enough so that no chunks are loaded more
555+
than once when processing a subvolume
556+
mask_cache_bytes: Number of bytes to use for the in-memory mask cache
557+
search_radius: Additional radius to extend patch_size by in every
558+
direction when extracting data for the 'previous' section
559+
"""
560+
561+
patch_size: int
562+
stride: int
563+
delta_z: int
564+
max_delta_z: int
565+
max_attempts: int = 2
566+
mask_configs: str | mask_lib.MaskConfigs | None = None
567+
mask_only_for_patch_selection: bool = True
568+
selection_mask_configs: str | mask_lib.MaskConfigs | None = None
569+
min_peak_ratio: float = 1.6
570+
min_peak_sharpness: float = 1.6
571+
max_magnitude: int = 40
572+
batch_size: int = 1024
573+
image_volinfo: str | None = None
574+
image_cache_bytes: int = int(1e9)
575+
mask_cache_bytes: int = int(1e9)
576+
search_radius: int = 0
577+
578+
_config: EstimateMissingFlowConfig
579+
580+
def __init__(
581+
self,
582+
config: EstimateMissingFlowConfig,
583+
input_volinfo_or_ts_spec=None,
584+
):
585+
"""Constructor.
586+
587+
Args:
588+
config: Parameters for EstimateMissingFlow
589+
input_volinfo_or_ts_spec: unused
590+
"""
591+
del input_volinfo_or_ts_spec
592+
593+
self._config = config
594+
595+
if config.patch_size % config.stride != 0:
596+
raise ValueError(
597+
f'patch_size {config.patch_size} not a multiple of stride'
598+
f' {config.stride}'
599+
)
600+
601+
self._search_patch_size = config.patch_size + config.search_radius * 2
602+
if self._search_patch_size % config.stride != 0:
603+
raise ValueError(
604+
f'search_patch_size {self._search_patch_size} not a multiple of'
605+
f' stride {config.stride}'
606+
)
607+
608+
if config.mask_configs is not None:
609+
config.mask_configs = self._get_mask_configs(config.mask_configs)
610+
611+
if config.selection_mask_configs is not None:
612+
config.selection_mask_configs = self._get_mask_configs(
613+
config.selection_mask_configs
614+
)
615+
616+
def _get_mask_configs(self, mask_configs: str) -> mask_lib.MaskConfigs:
617+
raise NotImplementedError(
618+
'This function needs to be defined in a subclass.'
619+
)
620+
621+
def _open_volume(self, path: file.PathLike) -> base.Volume:
622+
raise NotImplementedError(
623+
'This function needs to be defined in a subclass.'
624+
)
625+
626+
def _build_mask(
627+
self,
628+
mask_configs: mask_lib.MaskConfigs,
629+
# TODO(blakely): Switch to BoundingBox after move to 3p.
630+
box: bounding_box.BoundingBoxBase,
631+
) -> Any:
632+
"""Returns a CZYX-shaped ndarray-like object."""
633+
raise NotImplementedError(
634+
'This function needs to be defined in a subclass.'
635+
)
636+
637+
def num_channels(self, input_channels):
638+
"""Returns the number of channels in the output volume.
639+
640+
Args:
641+
input_channels: The number of channels in the input volume.
642+
643+
Returned channels are `flow_x, flow_y, lookback_z`. The latter represents
644+
how far back in the stack the processor had to look to find a valid flow
645+
calculation.
646+
"""
647+
del input_channels
648+
return 3
649+
650+
def process(self, subvol: Subvolume) -> SubvolumeOrMany:
651+
box = subvol.bbox
652+
input_ndarray = subvol.data
653+
namespace = 'estimate-missing-flow'
654+
beam_utils.counter(namespace, 'subvolumes-started').inc()
655+
656+
image_volume = self._open_volume(self._config.image_volinfo)
657+
658+
# Bounding box identifying the region of the image for which the input
659+
# flow was computed.
660+
stride = self._config.stride
661+
full_image_box = bounding_box.BoundingBox(
662+
start=(
663+
box.start[0] * stride - self._search_patch_size // 2,
664+
box.start[1] * stride - self._search_patch_size // 2,
665+
box.start[2],
666+
),
667+
size=(
668+
(box.size[0] - 1) * stride + self._search_patch_size,
669+
(box.size[1] - 1) * stride + self._search_patch_size,
670+
1,
671+
),
672+
)
673+
prev_image_box = image_volume.clip_box_to_volume(full_image_box)
674+
assert prev_image_box is not None
675+
676+
# Nothing to do if we don't have sufficient image context for any
677+
# flow field entries.
678+
if np.any(prev_image_box.size[:2] <= self._search_patch_size):
679+
return subvol
680+
681+
# Do not process flow field entries for which we do not have sufficient
682+
# image context.
683+
offset = prev_image_box.translate(-full_image_box.start).start // stride
684+
out_box = box.adjusted_by(start=offset)
685+
input_ndarray = input_ndarray[:, :, offset[1] :, offset[0] :]
686+
687+
# ceil_div
688+
offset = -((prev_image_box.end - full_image_box.end) // stride)
689+
out_box = out_box.adjusted_by(end=-offset)
690+
input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]]
691+
692+
patch_size = self._config.patch_size
693+
curr_image_box = bounding_box.BoundingBox(
694+
start=(
695+
out_box.start[0] * stride - patch_size // 2,
696+
out_box.start[1] * stride - patch_size // 2,
697+
out_box.start[2],
698+
),
699+
size=(
700+
(out_box.size[0] - 1) * stride + patch_size,
701+
(out_box.size[1] - 1) * stride + patch_size,
702+
1,
703+
),
704+
)
705+
curr_image_box = image_volume.clip_box_to_volume(curr_image_box)
706+
assert curr_image_box is not None
707+
708+
# The input flow forms the initial state of the output. We will try
709+
# to fill-in any invalid (NaN) pixels by computing flow against
710+
# earlier sections.
711+
ret = np.zeros([3] + list(out_box.size[::-1]))
712+
ret[:2, ...] = input_ndarray
713+
ret[2, ...] = self._config.delta_z
714+
715+
sel_mask = None
716+
if self._config.selection_mask_config is not None:
717+
sel_mask = self._build_mask(self._config.selection_mask_config, out_box)
718+
719+
mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
720+
invalid = np.isnan(input_ndarray[0, ...])
721+
for z in range(0, invalid.shape[0]):
722+
z0 = box.start[2] + z
723+
logging.info('Processing rel_z=%d abs_z=%d', z, z0)
724+
725+
if np.all(~invalid[z, ...]):
726+
beam_utils.counter(namespace, 'sections-already-valid').inc()
727+
continue
728+
729+
image_box = curr_image_box.translate([0, 0, z])
730+
curr_mask = None
731+
if self._config.mask_config is not None:
732+
curr_mask = self._build_mask(
733+
self._config.mask_config, image_box
734+
).squeeze()
735+
if np.all(curr_mask):
736+
beam_utils.counter(namespace, 'sections-masked').inc()
737+
continue
738+
739+
logging.info('Mask built.')
740+
741+
attempts = np.zeros(ret.shape[2:], dtype=int)
742+
mask = ~np.isfinite(ret[0, z, ...])
743+
if sel_mask is not None:
744+
mask &= sel_mask[z, ...]
745+
746+
curr = image_volume.asarray[image_box.to_slice4d()].squeeze()
747+
748+
delta_z = self._config.delta_z
749+
if delta_z > 0:
750+
rng = range(delta_z + 1, self._config.max_delta_z + 1)
751+
else:
752+
rng = range(delta_z - 1, self._config.max_delta_z - 1, -1)
753+
754+
for delta_z in rng:
755+
if (
756+
box.start[2] - delta_z < 0
757+
or box.end[2] - delta_z >= image_volume.volume_size[2]
758+
):
759+
break
760+
761+
t_start = time.time()
762+
prev_box = prev_image_box.translate([0, 0, z - delta_z])
763+
logging.info('Trying delta_z=%d (%r)', delta_z, prev_box)
764+
prev = image_volume.asarray[prev_box.to_slice4d()].squeeze()
765+
logging.info('.. image loaded.')
766+
t1 = time.time()
767+
768+
if self._config.mask_config is not None:
769+
prev_mask = self._build_mask(
770+
self._config.mask_config, prev_box
771+
).squeeze()
772+
if np.all(prev_mask):
773+
continue
774+
else:
775+
prev_mask = None
776+
logging.info('.. mask loaded.')
777+
778+
# Limit the number of estimation attempts per voxel. Attempts
779+
# are only counted when voxels in both sections are unmasked.
780+
mask &= attempts <= self._config.max_attempts
781+
if not np.any(mask):
782+
break
783+
784+
logging.info('.. points to evaluate: %d', np.sum(mask))
785+
t2 = time.time()
786+
787+
flow = mfc.flow_field(
788+
prev,
789+
curr,
790+
self._search_patch_size,
791+
self._config.stride,
792+
prev_mask,
793+
curr_mask,
794+
mask_only_for_patch_selection=self._config.mask_only_for_patch_selection,
795+
selection_mask=mask,
796+
batch_size=self._config.batch_size,
797+
post_patch_size=self._config.patch_size,
798+
)
799+
800+
t3 = time.time()
801+
valid = np.isfinite(flow[0, ...])
802+
attempts[: valid.shape[0], : valid.shape[1]][valid] += 1
803+
804+
flow = flow_utils.clean_flow(
805+
flow[:, np.newaxis, ...], #
806+
self._config.min_peak_ratio,
807+
self._config.min_peak_sharpness,
808+
self._config.max_magnitude,
809+
max_deviation=0.0,
810+
)
811+
812+
t4 = time.time()
813+
sy, sx = flow.shape[2:]
814+
to_update = mask[:sy, :sx] & np.isfinite(flow[0, 0, ...])
815+
mask[:sy, :sx][to_update] = False
816+
logging.info('.. points to update: %d', np.sum(to_update))
817+
818+
beam_utils.counter(namespace, f'sections-filled-delta{delta_z}').inc(
819+
np.sum(to_update)
820+
)
821+
ret[2, z, :sy, :sx][to_update] = delta_z
822+
ret[0, z, :sy, :sx][to_update] = flow[0, 0, ...][to_update]
823+
ret[1, z, :sy, :sx][to_update] = flow[1, 0, ...][to_update]
824+
t5 = time.time()
825+
826+
logging.info(
827+
'timings: img:%.2f mask:%.2f flow:%.2f clean:%.2f update:%.2f',
828+
t1 - t_start,
829+
t2 - t1,
830+
t3 - t2,
831+
t4 - t3,
832+
t5 - t4,
833+
)
834+
835+
return Subvolume(ret, out_box)

0 commit comments

Comments
 (0)