Skip to content

Commit 0becf87

Browse files
timblakelycopybara-github
authored andcommitted
Updates to EstimateFlow stages.
PiperOrigin-RevId: 667656706
1 parent 08b1cfe commit 0becf87

File tree

1 file changed

+82
-77
lines changed

1 file changed

+82
-77
lines changed

processor/flow.py

Lines changed: 82 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class EstimateFlowConfig(utils.NPDataClassJsonMixin):
8686
stride: int
8787
z_stride: int = 1
8888
fixed_current: bool = False
89-
mask_configs: mask_lib.MaskConfigs | None = None
89+
mask_configs: str | mask_lib.MaskConfigs | None = None
9090
mask_only_for_patch_selection: bool = False
9191
selection_mask_configs: mask_lib.MaskConfigs | None = None
9292
batch_size: int = 1024
@@ -103,9 +103,32 @@ def __init__(self, config: EstimateFlowConfig, input_volinfo_or_ts_spec=None):
103103

104104
del input_volinfo_or_ts_spec
105105
self._config = config
106-
107106
assert config.patch_size % config.stride == 0
108107

108+
if config.mask_configs is not None:
109+
if isinstance(config.mask_configs, str):
110+
config.mask_configs = self._get_mask_configs(config.mask_configs)
111+
112+
if config.selection_mask_configs is not None:
113+
if isinstance(config.selection_mask_configs, str):
114+
config.selection_mask_configs = self._get_mask_configs(
115+
config.selection_mask_configs
116+
)
117+
118+
def _get_mask_configs(self, mask_configs: str) -> mask_lib.MaskConfigs:
119+
raise NotImplementedError(
120+
'This function needs to be defined in a subclass.'
121+
)
122+
123+
def _build_mask(
124+
self,
125+
mask_configs: mask_lib.MaskConfigs,
126+
box: bounding_box.BoundingBoxBase,
127+
) -> Any:
128+
raise NotImplementedError(
129+
'This function needs to be defined in a subclass.'
130+
)
131+
109132
def output_type(self, input_type):
110133
return np.float32
111134

@@ -114,19 +137,18 @@ def subvolume_size(self):
114137
return subvolume_processor.SuggestedXyz(size, size, 16)
115138

116139
def context(self):
117-
config = self._config
118-
pre = config.patch_size // 2
119-
post = config.patch_size - pre
120-
if config.fixed_current:
121-
if config.z_stride > 0:
122-
return (pre, pre, 0), (post, post, config.z_stride)
140+
pre = self._config.patch_size // 2
141+
post = self._config.patch_size - pre
142+
if self._config.fixed_current:
143+
if self._config.z_stride > 0:
144+
return (pre, pre, 0), (post, post, self._config.z_stride)
123145
else:
124-
return (pre, pre, -config.z_stride), (post, post, 0)
146+
return (pre, pre, -self._config.z_stride), (post, post, 0)
125147
else:
126-
if config.z_stride > 0:
127-
return (pre, pre, config.z_stride), (post, post, 0)
148+
if self._config.z_stride > 0:
149+
return (pre, pre, self._config.z_stride), (post, post, 0)
128150
else:
129-
return (pre, pre, 0), (post, post, -config.z_stride)
151+
return (pre, pre, 0), (post, post, -self._config.z_stride)
130152

131153
def num_channels(self, input_channels):
132154
del input_channels
@@ -136,51 +158,37 @@ def num_channels(self, input_channels):
136158
)
137159

138160
def pixelsize(self, psize):
139-
psize = np.asarray(psize).copy().astype(np.float32)
161+
psize = psize.copy().astype(np.float32)
140162
psize[:2] *= self._config.stride
141163
return psize
142164

143-
def process(self, subvol: subvolume.Subvolume) -> subvolume.Subvolume:
144-
# TODO(blakely): Determine if Dask supports metrics, and if so, create a
145-
# shim that supports both Beam and Dask metrics.
146-
config = self._config
165+
def process(self, subvol: Subvolume) -> SubvolumeOrMany:
166+
box = subvol.bbox
167+
input_ndarray = subvol.data
168+
beam_utils.counter(self.namespace, 'subvolumes-started').inc()
147169

148-
assert subvol.data.shape[0], 'Input volume should have 1 channel.'
149-
image = subvol.data[0, ...]
150-
sel_mask = initial_mask = None
170+
assert input_ndarray.shape[0], 'Input volume should have 1 channel.'
171+
image = input_ndarray[0, ...]
172+
sel_mask = mask = None
151173

152-
if config.mask_configs is not None:
153-
# TODO(blakely): Remove the unused lambda here and below when the external
154-
# paths support DecoratorSpecs.
155-
initial_mask = mask_lib.build_mask(
156-
config.mask_configs, subvol.bbox, lambda x: x
157-
)
174+
with beam_utils.timer_counter(self.namespace, 'build-mask'):
175+
if self._config.mask_config is not None:
176+
mask = self._build_mask(self._config.mask_config, box)
158177

159-
if config.selection_mask_configs is not None:
160-
cropped_bbox = self.crop_box(subvol.bbox)
161-
sel_start = [
162-
cropped_bbox.start[0] / config.stride,
163-
cropped_bbox.start[1] / config.stride,
164-
subvol.bbox.start[2],
165-
]
166-
xy = np.array([1, 1, 0])
167-
scale = np.array([config.stride, config.stride, 1])
168-
sel_size = (
169-
subvol.bbox.size - xy * config.patch_size + xy * config.stride
170-
) / scale
171-
sel_box = bounding_box.BoundingBox(sel_start, sel_size)
172-
sel_mask = mask_lib.build_mask(
173-
config.selection_mask_configs, sel_box, lambda x: x
174-
)
178+
if self._config.selection_mask_config is not None:
179+
sel_box = box.scale(
180+
[1.0 / self._config.stride, 1.0 / self._config.stride, 1]
181+
)
182+
sel_mask = self._build_mask(self._config.selection_mask_config, sel_box)
175183

176184
def _estimate_flow(z_prev, z_curr):
177185
mask_prev = mask_curr = None
178186
prev = image[z_prev, ...]
179187
curr = image[z_curr, ...]
180188

181-
if initial_mask is not None:
182-
mask_prev = initial_mask[z_prev, ...]
183-
mask_curr = initial_mask[z_curr, ...]
189+
if mask is not None:
190+
mask_prev = mask[z_prev, ...]
191+
mask_curr = mask[z_curr, ...]
184192

185193
smask = None
186194
if sel_mask is not None:
@@ -189,55 +197,52 @@ def _estimate_flow(z_prev, z_curr):
189197
return mfc.flow_field(
190198
prev,
191199
curr,
192-
config.patch_size,
193-
config.stride,
200+
self._config.patch_size,
201+
self._config.stride,
194202
mask_prev,
195203
mask_curr,
196-
mask_only_for_patch_selection=config.mask_only_for_patch_selection,
204+
mask_only_for_patch_selection=self._config.mask_only_for_patch_selection,
197205
selection_mask=smask,
198-
batch_size=config.batch_size,
206+
batch_size=self._config.batch_size,
199207
)
200208

201-
mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
202-
flows = []
203-
204-
if config.fixed_current:
205-
if config.z_stride > 0:
206-
rng = range(0, image.shape[0] - 1)
207-
z_curr = image.shape[0] - 1
209+
with beam_utils.timer_counter(self.namespace, 'flow'):
210+
mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
211+
flows = []
212+
213+
if self._config.fixed_current:
214+
if self._config.z_stride > 0:
215+
rng = range(0, image.shape[0] - 1)
216+
z_curr = image.shape[0] - 1
217+
else:
218+
rng = range(1, image.shape[0])
219+
z_curr = 0
220+
for z_prev in rng:
221+
flows.append(_estimate_flow(z_prev, z_curr))
208222
else:
209-
rng = range(1, image.shape[0])
210-
z_curr = 0
211-
for z_prev in rng:
212-
flows.append(_estimate_flow(z_prev, z_curr))
213-
else:
214-
if config.z_stride > 0:
215-
rng = range(0, image.shape[0] - config.z_stride)
216-
else:
217-
rng = range(-config.z_stride, image.shape[0])
223+
if self._config.z_stride > 0:
224+
rng = range(0, image.shape[0] - self._config.z_stride)
225+
else:
226+
rng = range(-self._config.z_stride, image.shape[0])
218227

219-
for z in rng:
220-
flows.append(_estimate_flow(z, z + config.z_stride))
228+
for z in rng:
229+
flows.append(_estimate_flow(z, z + self._config.z_stride))
221230

222231
ret = np.array(flows)
223232

224233
# Output starts at:
225234
# Δz > 0: box.start.z + Δz
226235
# Δz < 0: box.start.z
227-
out_box = self.crop_box(subvol.bbox)
236+
out_box = self.crop_box(box)
228237
out_box = bounding_box.BoundingBox(
229-
start=out_box.start // [config.stride, config.stride, 1],
238+
start=out_box.start // [self._config.stride, self._config.stride, 1],
230239
size=[ret.shape[-1], ret.shape[-2], out_box.size[2]],
231240
)
241+
if ret.shape[0] != out_box.size[2]:
242+
raise ValueError(f'ret:{ret.shape} vs out:{out_box.size}')
232243

233-
expected_box = self.expected_output_box(subvol.bbox)
234-
if out_box != expected_box:
235-
raise ValueError(
236-
f'Bounding box does not match expected output_box {out_box} vs '
237-
f'{expected_box}'
238-
)
239-
240-
return subvolume.Subvolume(np.transpose(ret, (1, 0, 2, 3)), out_box)
244+
beam_utils.counter(self.namespace, 'subvolumes-done').inc()
245+
return Subvolume(np.transpose(ret, (1, 0, 2, 3)), out_box)
241246

242247
# Because mfc.flow_field does not take into account the standard subvolume
243248
# processor overlap schemes - the latter knows nothing about the internal

0 commit comments

Comments
 (0)