|
15 | 15 | """Flow field estimation from SOFIMA."""
|
16 | 16 |
|
17 | 17 | import dataclasses
|
| 18 | +import time |
18 | 19 | from typing import Any, Sequence
|
| 20 | + |
| 21 | +from absl import logging |
19 | 22 | from connectomics.common import beam_utils
|
20 | 23 | from connectomics.common import bounding_box
|
| 24 | +from connectomics.common import file |
21 | 25 | from connectomics.common import utils
|
22 | 26 | from connectomics.volume import base
|
23 | 27 | from connectomics.volume import mask as mask_lib
|
@@ -370,7 +374,7 @@ def __init__(
|
370 | 374 | if isinstance(config.mask_configs, str):
|
371 | 375 | config.mask_configs = self._get_mask_configs(config.mask_configs)
|
372 | 376 |
|
373 |
| - def _open_volume(self, path: str) -> base.Volume: |
| 377 | + def _open_volume(self, path: file.PathLike) -> base.Volume: |
374 | 378 | """Returns a CZYX-shaped ndarray-like object."""
|
375 | 379 | raise NotImplementedError(
|
376 | 380 | 'This function needs to be defined in a subclass.'
|
@@ -508,3 +512,324 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
|
508 | 512 | self._config.min_patch_size,
|
509 | 513 | )
|
510 | 514 | return self.crop_box_and_data(box, ret)
|
| 515 | + |
| 516 | + |
class EstimateMissingFlow(subvolume_processor.SubvolumeProcessor):
  """Estimates a multi-section flow field.

  Takes an existing single-section (2-channel) flow volume as input,
  and tries to compute flow vectors which are invalid in the input (NaNs).
  """

  @dataclasses_json.dataclass_json
  @dataclasses.dataclass(frozen=True)
  class EstimateMissingFlowConfig:
    """Configuration for EstimateMissingFlow.

    Attributes:
      patch_size: Patch size in pixels, divisible by 'stride'
      stride: XY stride size in pixels
      delta_z: Z stride size in pixels (Δz) for the input volume
      max_delta_z: Maximum Z stride with which to try to estimate missing flow
        vectors
      max_attempts: Maximum number of attempts to estimate a flow vector for an
        unmasked location
      mask_configs: MaskConfigs proto in text format specifying a mask to
        exclude some voxels from the flow calculation
      mask_only_for_patch_selection: Whether to only use mask to decide for
        which patch pairs to compute flow
      selection_mask_configs: MaskConfigs in text format specifying a mask the
        positive entries of which indicate locations for which flow should be
        computed; this mask should have the same resolution and geometry as the
        output flow volume
      min_peak_ratio: Quality threshold for acceptance of newly estimated flow
        vectors; see flow_utils.clean_flow
      min_peak_sharpness: Quality threshold for acceptance of newly estimated
        flow vectors; see flow_utils.clean_flow
      max_magnitude: Maximum magnitude of a flow vector; see
        flow_utils.clean_flow
      batch_size: Max number of patches to process in parallel
      image_volinfo: Path to the VolumeInfo descriptor of the image volume
      image_cache_bytes: Number of bytes to use for the in-memory image cache;
        this should ideally be large enough so that no chunks are loaded more
        than once when processing a subvolume
      mask_cache_bytes: Number of bytes to use for the in-memory mask cache
      search_radius: Additional radius to extend patch_size by in every
        direction when extracting data for the 'previous' section
    """

    patch_size: int
    stride: int
    delta_z: int
    max_delta_z: int
    max_attempts: int = 2
    mask_configs: str | mask_lib.MaskConfigs | None = None
    mask_only_for_patch_selection: bool = True
    selection_mask_configs: str | mask_lib.MaskConfigs | None = None
    min_peak_ratio: float = 1.6
    min_peak_sharpness: float = 1.6
    max_magnitude: int = 40
    batch_size: int = 1024
    image_volinfo: str | None = None
    image_cache_bytes: int = int(1e9)
    mask_cache_bytes: int = int(1e9)
    search_radius: int = 0

  _config: EstimateMissingFlowConfig

  def __init__(
      self,
      config: EstimateMissingFlowConfig,
      input_volinfo_or_ts_spec=None,
  ):
    """Constructor.

    Args:
      config: Parameters for EstimateMissingFlow
      input_volinfo_or_ts_spec: unused

    Raises:
      ValueError: if `patch_size`, or `patch_size` extended by twice the
        `search_radius`, is not a multiple of `stride`.
    """
    del input_volinfo_or_ts_spec

    self._config = config

    if config.patch_size % config.stride != 0:
      raise ValueError(
          f'patch_size {config.patch_size} not a multiple of stride'
          f' {config.stride}'
      )

    self._search_patch_size = config.patch_size + config.search_radius * 2
    if self._search_patch_size % config.stride != 0:
      raise ValueError(
          f'search_patch_size {self._search_patch_size} not a multiple of'
          f' stride {config.stride}'
      )

    # The config dataclass is frozen, so attribute assignment would raise
    # FrozenInstanceError. Resolved mask configs are therefore stored in a
    # copy created with dataclasses.replace. Only text-format values are
    # passed to the `str`-typed _get_mask_configs hook; other values are
    # kept as provided.
    resolved = {}
    if isinstance(config.mask_configs, str):
      resolved['mask_configs'] = self._get_mask_configs(config.mask_configs)
    if isinstance(config.selection_mask_configs, str):
      resolved['selection_mask_configs'] = self._get_mask_configs(
          config.selection_mask_configs
      )
    if resolved:
      self._config = dataclasses.replace(config, **resolved)

  def _get_mask_configs(self, mask_configs: str) -> mask_lib.MaskConfigs:
    """Converts a text-format mask specification; subclass hook."""
    raise NotImplementedError(
        'This function needs to be defined in a subclass.'
    )

  def _open_volume(self, path: file.PathLike) -> base.Volume:
    """Opens the volume identified by `path`; subclass hook."""
    raise NotImplementedError(
        'This function needs to be defined in a subclass.'
    )

  def _build_mask(
      self,
      mask_configs: mask_lib.MaskConfigs,
      # TODO(blakely): Switch to BoundingBox after move to 3p.
      box: bounding_box.BoundingBoxBase,
  ) -> Any:
    """Returns a CZYX-shaped ndarray-like object."""
    raise NotImplementedError(
        'This function needs to be defined in a subclass.'
    )

  def num_channels(self, input_channels):
    """Returns the number of channels in the output volume.

    Args:
      input_channels: The number of channels in the input volume.

    Returned channels are `flow_x, flow_y, lookback_z`. The latter represents
    how far back in the stack the processor had to look to find a valid flow
    calculation.
    """
    del input_channels
    return 3

  def _flow_box_to_image_box(
      self, box: bounding_box.BoundingBoxBase, patch_size: int
  ) -> bounding_box.BoundingBox:
    """Returns the single-section image box covering patches of a flow box.

    Args:
      box: Box in flow-field coordinates (XY scaled down by the stride).
      patch_size: XY size in pixels of the patch centered on each flow entry.

    Returns:
      A box of Z size 1, starting at the Z start of `box`, that spans in XY
      all `patch_size` patches centered on the flow entries of `box`.
    """
    stride = self._config.stride
    half = patch_size // 2
    return bounding_box.BoundingBox(
        start=(
            box.start[0] * stride - half,
            box.start[1] * stride - half,
            box.start[2],
        ),
        size=(
            (box.size[0] - 1) * stride + patch_size,
            (box.size[1] - 1) * stride + patch_size,
            1,
        ),
    )

  def process(self, subvol: Subvolume) -> SubvolumeOrMany:
    """Fills in invalid (NaN) flow entries using more distant sections.

    For every section that contains NaN entries in the first flow channel,
    flow is re-estimated against sections at successively larger Z offsets,
    up to `max_delta_z`. Newly estimated vectors that survive
    `flow_utils.clean_flow` are written to the output.

    Args:
      subvol: Input flow subvolume; the first two channels are copied into
        the output.

    Returns:
      A 3-channel subvolume holding the two flow channels and the Z offset
      used for each entry. It is restricted to the entries with sufficient
      image context. The input is returned unchanged if no entry has
      sufficient image context.
    """
    box = subvol.bbox
    input_ndarray = subvol.data
    namespace = 'estimate-missing-flow'
    beam_utils.counter(namespace, 'subvolumes-started').inc()

    config = self._config
    stride = config.stride
    image_volume = self._open_volume(config.image_volinfo)

    # Bounding box identifying the region of the image for which the input
    # flow was computed.
    full_image_box = self._flow_box_to_image_box(box, self._search_patch_size)
    prev_image_box = image_volume.clip_box_to_volume(full_image_box)
    assert prev_image_box is not None

    # Nothing to do if we don't have sufficient image context for any
    # flow field entries.
    # NOTE(review): this path returns the input as-is, i.e. without the
    # third (lookback) channel -- confirm that callers handle this.
    if np.any(prev_image_box.size[:2] <= self._search_patch_size):
      return subvol

    # Do not process flow field entries for which we do not have sufficient
    # image context.
    offset = prev_image_box.translate(-full_image_box.start).start // stride
    out_box = box.adjusted_by(start=offset)
    input_ndarray = input_ndarray[:, :, offset[1] :, offset[0] :]

    # ceil_div
    offset = -((prev_image_box.end - full_image_box.end) // stride)
    out_box = out_box.adjusted_by(end=-offset)
    input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]]

    curr_image_box = image_volume.clip_box_to_volume(
        self._flow_box_to_image_box(out_box, config.patch_size)
    )
    assert curr_image_box is not None

    # The input flow forms the initial state of the output. We will try
    # to fill-in any invalid (NaN) pixels by computing flow against
    # earlier sections.
    ret = np.zeros([3] + list(out_box.size[::-1]))
    ret[:2, ...] = input_ndarray
    ret[2, ...] = config.delta_z

    sel_mask = None
    if config.selection_mask_configs is not None:
      sel_mask = self._build_mask(config.selection_mask_configs, out_box)

    mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
    invalid = np.isnan(input_ndarray[0, ...])
    for z in range(0, invalid.shape[0]):
      z0 = box.start[2] + z
      logging.info('Processing rel_z=%d abs_z=%d', z, z0)

      if np.all(~invalid[z, ...]):
        beam_utils.counter(namespace, 'sections-already-valid').inc()
        continue

      image_box = curr_image_box.translate([0, 0, z])
      curr_mask = None
      if config.mask_configs is not None:
        curr_mask = self._build_mask(config.mask_configs, image_box).squeeze()
        if np.all(curr_mask):
          beam_utils.counter(namespace, 'sections-masked').inc()
          continue

        logging.info('Mask built.')

      attempts = np.zeros(ret.shape[2:], dtype=int)
      mask = ~np.isfinite(ret[0, z, ...])
      if sel_mask is not None:
        # NOTE(review): indexes the first axis of the selection mask with the
        # section number; _build_mask is documented as CZYX-shaped -- confirm
        # the expected layout of its result here.
        mask &= sel_mask[z, ...]

      curr = image_volume.asarray[image_box.to_slice4d()].squeeze()

      # Candidate Z offsets, starting one step beyond the offset of the
      # input flow and moving towards max_delta_z.
      delta_z = config.delta_z
      if delta_z > 0:
        rng = range(delta_z + 1, config.max_delta_z + 1)
      else:
        rng = range(delta_z - 1, config.max_delta_z - 1, -1)

      for delta_z in rng:
        # NOTE(review): the bounds check uses the subvolume box rather than
        # the current section (z0) -- confirm this is intended.
        if (
            box.start[2] - delta_z < 0
            or box.end[2] - delta_z >= image_volume.volume_size[2]
        ):
          break

        t_start = time.time()
        prev_box = prev_image_box.translate([0, 0, z - delta_z])
        logging.info('Trying delta_z=%d (%r)', delta_z, prev_box)
        prev = image_volume.asarray[prev_box.to_slice4d()].squeeze()
        logging.info('.. image loaded.')
        t1 = time.time()

        if config.mask_configs is not None:
          prev_mask = self._build_mask(config.mask_configs, prev_box).squeeze()
          if np.all(prev_mask):
            continue
        else:
          prev_mask = None
        logging.info('.. mask loaded.')

        # Limit the number of estimation attempts per voxel. Attempts
        # are only counted when voxels in both sections are unmasked.
        # NOTE(review): `<=` permits max_attempts + 1 attempts per location
        # -- confirm whether `<` was intended.
        mask &= attempts <= config.max_attempts
        if not np.any(mask):
          break

        logging.info('.. points to evaluate: %d', np.sum(mask))
        t2 = time.time()

        flow = mfc.flow_field(
            prev,
            curr,
            self._search_patch_size,
            config.stride,
            prev_mask,
            curr_mask,
            mask_only_for_patch_selection=config.mask_only_for_patch_selection,
            selection_mask=mask,
            batch_size=config.batch_size,
            post_patch_size=config.patch_size,
        )

        t3 = time.time()
        valid = np.isfinite(flow[0, ...])
        attempts[: valid.shape[0], : valid.shape[1]][valid] += 1

        flow = flow_utils.clean_flow(
            flow[:, np.newaxis, ...],  #
            config.min_peak_ratio,
            config.min_peak_sharpness,
            config.max_magnitude,
            max_deviation=0.0,
        )

        t4 = time.time()
        sy, sx = flow.shape[2:]
        to_update = mask[:sy, :sx] & np.isfinite(flow[0, 0, ...])
        # Locations that received a valid vector need no further attempts.
        mask[:sy, :sx][to_update] = False
        logging.info('.. points to update: %d', np.sum(to_update))

        beam_utils.counter(namespace, f'sections-filled-delta{delta_z}').inc(
            np.sum(to_update)
        )
        ret[2, z, :sy, :sx][to_update] = delta_z
        ret[0, z, :sy, :sx][to_update] = flow[0, 0, ...][to_update]
        ret[1, z, :sy, :sx][to_update] = flow[1, 0, ...][to_update]
        t5 = time.time()

        logging.info(
            'timings: img:%.2f mask:%.2f flow:%.2f clean:%.2f update:%.2f',
            t1 - t_start,
            t2 - t1,
            t3 - t2,
            t4 - t3,
            t5 - t4,
        )

    return Subvolume(ret, out_box)
0 commit comments