@@ -86,7 +86,7 @@ class EstimateFlowConfig(utils.NPDataClassJsonMixin):
86
86
stride : int
87
87
z_stride : int = 1
88
88
fixed_current : bool = False
89
- mask_configs : mask_lib .MaskConfigs | None = None
89
+ mask_configs : str | mask_lib .MaskConfigs | None = None
90
90
mask_only_for_patch_selection : bool = False
91
91
selection_mask_configs : mask_lib .MaskConfigs | None = None
92
92
batch_size : int = 1024
@@ -103,9 +103,32 @@ def __init__(self, config: EstimateFlowConfig, input_volinfo_or_ts_spec=None):
103
103
104
104
del input_volinfo_or_ts_spec
105
105
self ._config = config
106
-
107
106
assert config .patch_size % config .stride == 0
108
107
108
+ if config .mask_configs is not None :
109
+ if isinstance (config .mask_configs , str ):
110
+ config .mask_configs = self ._get_mask_configs (config .mask_configs )
111
+
112
+ if config .selection_mask_configs is not None :
113
+ if isinstance (config .selection_mask_configs , str ):
114
+ config .selection_mask_configs = self ._get_mask_configs (
115
+ config .selection_mask_configs
116
+ )
117
+
118
+ def _get_mask_configs (self , mask_configs : str ) -> mask_lib .MaskConfigs :
119
+ raise NotImplementedError (
120
+ 'This function needs to be defined in a subclass.'
121
+ )
122
+
123
+ def _build_mask (
124
+ self ,
125
+ mask_configs : mask_lib .MaskConfigs ,
126
+ box : bounding_box .BoundingBoxBase ,
127
+ ) -> Any :
128
+ raise NotImplementedError (
129
+ 'This function needs to be defined in a subclass.'
130
+ )
131
+
109
132
def output_type (self , input_type ):
110
133
return np .float32
111
134
@@ -114,19 +137,18 @@ def subvolume_size(self):
114
137
return subvolume_processor .SuggestedXyz (size , size , 16 )
115
138
116
139
def context (self ):
117
- config = self ._config
118
- pre = config .patch_size // 2
119
- post = config .patch_size - pre
120
- if config .fixed_current :
121
- if config .z_stride > 0 :
122
- return (pre , pre , 0 ), (post , post , config .z_stride )
140
+ pre = self ._config .patch_size // 2
141
+ post = self ._config .patch_size - pre
142
+ if self ._config .fixed_current :
143
+ if self ._config .z_stride > 0 :
144
+ return (pre , pre , 0 ), (post , post , self ._config .z_stride )
123
145
else :
124
- return (pre , pre , - config .z_stride ), (post , post , 0 )
146
+ return (pre , pre , - self . _config .z_stride ), (post , post , 0 )
125
147
else :
126
- if config .z_stride > 0 :
127
- return (pre , pre , config .z_stride ), (post , post , 0 )
148
+ if self . _config .z_stride > 0 :
149
+ return (pre , pre , self . _config .z_stride ), (post , post , 0 )
128
150
else :
129
- return (pre , pre , 0 ), (post , post , - config .z_stride )
151
+ return (pre , pre , 0 ), (post , post , - self . _config .z_stride )
130
152
131
153
def num_channels (self , input_channels ):
132
154
del input_channels
@@ -136,51 +158,37 @@ def num_channels(self, input_channels):
136
158
)
137
159
138
160
def pixelsize (self , psize ):
139
- psize = np . asarray ( psize ) .copy ().astype (np .float32 )
161
+ psize = psize .copy ().astype (np .float32 )
140
162
psize [:2 ] *= self ._config .stride
141
163
return psize
142
164
143
- def process (self , subvol : subvolume . Subvolume ) -> subvolume . Subvolume :
144
- # TODO(blakely): Determine if Dask supports metrics, and if so, create a
145
- # shim that supports both Beam and Dask metrics.
146
- config = self ._config
165
+ def process (self , subvol : Subvolume ) -> SubvolumeOrMany :
166
+ box = subvol . bbox
167
+ input_ndarray = subvol . data
168
+ beam_utils . counter ( self .namespace , 'subvolumes-started' ). inc ()
147
169
148
- assert subvol . data .shape [0 ], 'Input volume should have 1 channel.'
149
- image = subvol . data [0 , ...]
150
- sel_mask = initial_mask = None
170
+ assert input_ndarray .shape [0 ], 'Input volume should have 1 channel.'
171
+ image = input_ndarray [0 , ...]
172
+ sel_mask = mask = None
151
173
152
- if config .mask_configs is not None :
153
- # TODO(blakely): Remove the unused lambda here and below when the external
154
- # paths support DecoratorSpecs.
155
- initial_mask = mask_lib .build_mask (
156
- config .mask_configs , subvol .bbox , lambda x : x
157
- )
174
+ with beam_utils .timer_counter (self .namespace , 'build-mask' ):
175
+ if self ._config .mask_config is not None :
176
+ mask = self ._build_mask (self ._config .mask_config , box )
158
177
159
- if config .selection_mask_configs is not None :
160
- cropped_bbox = self .crop_box (subvol .bbox )
161
- sel_start = [
162
- cropped_bbox .start [0 ] / config .stride ,
163
- cropped_bbox .start [1 ] / config .stride ,
164
- subvol .bbox .start [2 ],
165
- ]
166
- xy = np .array ([1 , 1 , 0 ])
167
- scale = np .array ([config .stride , config .stride , 1 ])
168
- sel_size = (
169
- subvol .bbox .size - xy * config .patch_size + xy * config .stride
170
- ) / scale
171
- sel_box = bounding_box .BoundingBox (sel_start , sel_size )
172
- sel_mask = mask_lib .build_mask (
173
- config .selection_mask_configs , sel_box , lambda x : x
174
- )
178
+ if self ._config .selection_mask_config is not None :
179
+ sel_box = box .scale (
180
+ [1.0 / self ._config .stride , 1.0 / self ._config .stride , 1 ]
181
+ )
182
+ sel_mask = self ._build_mask (self ._config .selection_mask_config , sel_box )
175
183
176
184
def _estimate_flow (z_prev , z_curr ):
177
185
mask_prev = mask_curr = None
178
186
prev = image [z_prev , ...]
179
187
curr = image [z_curr , ...]
180
188
181
- if initial_mask is not None :
182
- mask_prev = initial_mask [z_prev , ...]
183
- mask_curr = initial_mask [z_curr , ...]
189
+ if mask is not None :
190
+ mask_prev = mask [z_prev , ...]
191
+ mask_curr = mask [z_curr , ...]
184
192
185
193
smask = None
186
194
if sel_mask is not None :
@@ -189,55 +197,52 @@ def _estimate_flow(z_prev, z_curr):
189
197
return mfc .flow_field (
190
198
prev ,
191
199
curr ,
192
- config .patch_size ,
193
- config .stride ,
200
+ self . _config .patch_size ,
201
+ self . _config .stride ,
194
202
mask_prev ,
195
203
mask_curr ,
196
- mask_only_for_patch_selection = config .mask_only_for_patch_selection ,
204
+ mask_only_for_patch_selection = self . _config .mask_only_for_patch_selection ,
197
205
selection_mask = smask ,
198
- batch_size = config .batch_size ,
206
+ batch_size = self . _config .batch_size ,
199
207
)
200
208
201
- mfc = flow_field .JAXMaskedXCorrWithStatsCalculator ()
202
- flows = []
203
-
204
- if config .fixed_current :
205
- if config .z_stride > 0 :
206
- rng = range (0 , image .shape [0 ] - 1 )
207
- z_curr = image .shape [0 ] - 1
209
+ with beam_utils .timer_counter (self .namespace , 'flow' ):
210
+ mfc = flow_field .JAXMaskedXCorrWithStatsCalculator ()
211
+ flows = []
212
+
213
+ if self ._config .fixed_current :
214
+ if self ._config .z_stride > 0 :
215
+ rng = range (0 , image .shape [0 ] - 1 )
216
+ z_curr = image .shape [0 ] - 1
217
+ else :
218
+ rng = range (1 , image .shape [0 ])
219
+ z_curr = 0
220
+ for z_prev in rng :
221
+ flows .append (_estimate_flow (z_prev , z_curr ))
208
222
else :
209
- rng = range (1 , image .shape [0 ])
210
- z_curr = 0
211
- for z_prev in rng :
212
- flows .append (_estimate_flow (z_prev , z_curr ))
213
- else :
214
- if config .z_stride > 0 :
215
- rng = range (0 , image .shape [0 ] - config .z_stride )
216
- else :
217
- rng = range (- config .z_stride , image .shape [0 ])
223
+ if self ._config .z_stride > 0 :
224
+ rng = range (0 , image .shape [0 ] - self ._config .z_stride )
225
+ else :
226
+ rng = range (- self ._config .z_stride , image .shape [0 ])
218
227
219
- for z in rng :
220
- flows .append (_estimate_flow (z , z + config .z_stride ))
228
+ for z in rng :
229
+ flows .append (_estimate_flow (z , z + self . _config .z_stride ))
221
230
222
231
ret = np .array (flows )
223
232
224
233
# Output starts at:
225
234
# Δz > 0: box.start.z + Δz
226
235
# Δz < 0: box.start.z
227
- out_box = self .crop_box (subvol . bbox )
236
+ out_box = self .crop_box (box )
228
237
out_box = bounding_box .BoundingBox (
229
- start = out_box .start // [config . stride , config .stride , 1 ],
238
+ start = out_box .start // [self . _config . stride , self . _config .stride , 1 ],
230
239
size = [ret .shape [- 1 ], ret .shape [- 2 ], out_box .size [2 ]],
231
240
)
241
+ if ret .shape [0 ] != out_box .size [2 ]:
242
+ raise ValueError (f'ret:{ ret .shape } vs out:{ out_box .size } ' )
232
243
233
- expected_box = self .expected_output_box (subvol .bbox )
234
- if out_box != expected_box :
235
- raise ValueError (
236
- f'Bounding box does not match expected output_box { out_box } vs '
237
- f'{ expected_box } '
238
- )
239
-
240
- return subvolume .Subvolume (np .transpose (ret , (1 , 0 , 2 , 3 )), out_box )
244
+ beam_utils .counter (self .namespace , 'subvolumes-done' ).inc ()
245
+ return Subvolume (np .transpose (ret , (1 , 0 , 2 , 3 )), out_box )
241
246
242
247
# Because mfc.flow_field does not take into account the standard subvolume
243
248
# processor overlap schemes - the latter knows nothing about the internal
0 commit comments