@@ -39,10 +39,7 @@ def _get_image_id(image_path: str) -> str:
 
 
 def apply_stamp_to_stacc_gt_label(
-    stamp: np.ndarray,
-    position: Tuple[int, int],
-    label_matrix: np.ndarray,
-    image_id: str
+    stamp: np.ndarray, position: Tuple[int, int], label_matrix: np.ndarray, image_id: str
 ) -> None:
     """Add a stamp (small matrix) at the coordinates 'position' of the label_matrix.
 
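The body of this function is not part of the hunk; as orientation for the docstring, a minimal sketch of what pasting a centred stamp can look like. `_paste_stamp` is a hypothetical helper, not the actual implementation, and it ignores the border clipping the real function has to handle:

import numpy as np

def _paste_stamp(stamp: np.ndarray, position, label_matrix: np.ndarray) -> None:
    # Centre the stamp on `position` (row, col) and add it into the label matrix in place.
    h, w = stamp.shape
    r0, c0 = position[0] - h // 2, position[1] - w // 2
    label_matrix[r0:r0 + h, c0:c0 + w] += stamp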
@@ -123,7 +120,7 @@ def create_gaussian_stamp(width: int, lower_bound: float, upper_bound: float, ep
 
     stamp = np.zeros((width, width))
     stamp[width // 2, width // 2] = 1
-    stamp = gaussian(stamp, sigma=sigma, truncate=10.0, mode='constant')
+    stamp = gaussian(stamp, sigma=sigma, truncate=10.0, mode="constant")
     stamp[np.where(stamp < eps)] = 0  # Truncate to make the stamp circular
     stamp = stamp * 2 * 4 * np.pi * sigma ** 2
 
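Taken together, the lines in this hunk build the stamp roughly as in the sketch below. `make_stamp` is a hypothetical name, and `skimage.filters.gaussian` is the assumed source of `gaussian`:

import numpy as np
from skimage.filters import gaussian  # assumed import for `gaussian`

def make_stamp(width: int, sigma: float, eps: float = 1e-5) -> np.ndarray:
    stamp = np.zeros((width, width))
    stamp[width // 2, width // 2] = 1  # unit impulse at the centre pixel
    stamp = gaussian(stamp, sigma=sigma, truncate=10.0, mode="constant")  # blur into a Gaussian blob
    stamp[stamp < eps] = 0  # cut the faint tails so the stamp becomes circular
    return stamp * 2 * 4 * np.pi * sigma ** 2  # rescale the peak height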
@@ -139,16 +136,16 @@ def create_stacc_ground_truth_label_from_json(
     upper_bound: Optional[float] = None,
     bounding_box: Optional[Tuple[slice]] = None,
 ):
-    """Create a ground truth label matrix from JSON bounding box annotations.
+    """Create a ground truth label matrix from JSON point annotations.
 
     Args:
         image_path: Path to the input image.
-        label_path: Path to the corresponding JSON file containing bounding box labels.
+        label_path: Path to the corresponding JSON file containing point labels.
         eps: Epsilon value for Gaussian truncation. Default is 0.00001.
         sigma: Sigma value for Gaussian blur. If None, individual stamps are applied.
         lower_bound: Lower bound for Gaussian sigma. If None, no lower bound for sigma will be set.
         upper_bound: Upper bound for Gaussian sigma If None, no upper bound for sigma will be set.
-        bounding_box: A bounding box for creating he labels only in a sub-part of the image.
+        bounding_box: A bounding box for creating the labels only in a sub-part of the image.
 
     Returns:
         A 2D array representing the stacc ground truth label matrix.
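A hypothetical call, with placeholder paths, showing how the function is meant to be used after this change:

# Placeholder paths; per the docstring, sigma=None places one individual stamp per annotated point.
label = create_stacc_ground_truth_label_from_json(
    "images/plate_01.tif", "labels/plate_01.json", sigma=None
)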
@@ -158,9 +155,9 @@ def create_stacc_ground_truth_label_from_json(
 
     bboxes = label_dict["labels"]
     n_colonies = len(bboxes)  # Number of annotations / bounding boxes
+    image_id = _get_image_id(image_path)  # Get image id and check if jpg or tif image
 
     if bounding_box is None:
-        image_id = _get_image_id(image_path)  # Get image id and check if jpg or tif image
         im = imread(image_path)
         n_rows, n_columns = im.shape[:2]
     else:
@@ -171,17 +168,9 @@ def create_stacc_ground_truth_label_from_json(
     stacc_gt_label = np.zeros((n_rows, n_columns))  # Create empty ground truth label
     if n_colonies > 0:
         # Only keep the stacc_gt_label that are inside the image dimensions
-        reasonable_indices = [
-            i for i in range(n_colonies) if
-            (bboxes[i]["x"] + max(int(bboxes[i]["width"] / 2), 1)) <= n_columns and
-            (bboxes[i]["y"] + max(int(bboxes[i]["height"] / 2), 1)) <= n_rows
-        ]
-        x_coordinates = np.array(
-            [(int(bboxes[i]["y"]) + max(int(bboxes[i]["width"] / 2), 1)) for i in reasonable_indices], dtype="int"
-        )
-        y_coordinates = np.array(
-            [(int(bboxes[i]["x"]) + max(int(bboxes[i]["height"] / 2), 1)) for i in reasonable_indices], dtype="int"
-        )
+        reasonable_indices = [i for i in range(n_colonies) if bboxes[i]["x"] < n_columns and bboxes[i]["y"] < n_rows]
+        x_coordinates = np.array([int(bboxes[i]["y"]) for i in reasonable_indices], dtype="int")
+        y_coordinates = np.array([int(bboxes[i]["x"]) for i in reasonable_indices], dtype="int")
 
         # If we have a bounding box we need to subtract the lower corner and filter out points that are outside.
         if bounding_box is not None:
@@ -201,8 +190,8 @@ def create_stacc_ground_truth_label_from_json(
         else:
             # Process each coordinate individually
             for i, (x_coord, y_coord) in enumerate(zip(x_coordinates, y_coordinates)):
-                width = max(int(bboxes[reasonable_indices[i]]['width']), 1)
-                height = max(int(bboxes[reasonable_indices[i]]['height']), 1)
+                width = max(int(bboxes[reasonable_indices[i]]["width"]), 1)
+                height = max(int(bboxes[reasonable_indices[i]]["height"]), 1)
                 width = min(width, height)  # Make the stamp square
                 if width % 2 == 0:
                     width -= 1
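The annotation schema assumed by this code, inferred from the keys accessed in the two hunks above; the values below are illustrative only:

# One entry of label_dict["labels"]: a point given by "x"/"y", plus a box size used to derive the stamp width.
example_label = {"x": 120, "y": 85, "width": 9, "height": 11}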
@@ -267,8 +256,8 @@ def create_stacc_labels_from_csv(
 
 
 class StaccImageCollectionDataset(torch.utils.data.Dataset):
-    """@private
-    """
+    """@private"""
+
     max_sampling_attempts = 500
 
     def _check_inputs(self, raw_images, label_images):
@@ -337,10 +326,7 @@ def _sample_bounding_box(self, shape):
             raise NotImplementedError(
                 f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
             )
-        bb_start = [
-            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
-            for sh, psh in zip(shape, self.patch_shape)
-        ]
+        bb_start = [np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)]
         return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
 
     def _get_sample(self, index):
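The reformatted one-liner implements the same random patch sampling as before; as a standalone sketch with assumed 2D image and patch sizes:

import numpy as np

shape, patch_shape = (1024, 1024), (256, 256)  # assumed image and patch shapes
bb_start = [np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, patch_shape)]
bb = tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape))  # random patch, e.g. (slice(412, 668), slice(90, 346))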
@@ -358,9 +344,13 @@ def _get_sample(self, index):
         _, label_extension = os.path.splitext(label_path)
         if label_extension == ".json":
             label_patch = create_stacc_ground_truth_label_from_json(
-                raw_path, label_path, eps=self.eps, sigma=self.sigma,
-                lower_bound=self.lower_bound, upper_bound=self.upper_bound,
-                bounding_box=bb
+                raw_path,
+                label_path,
+                eps=self.eps,
+                sigma=self.sigma,
+                lower_bound=self.lower_bound,
+                upper_bound=self.upper_bound,
+                bounding_box=bb,
             )
         elif label_extension.lower() == ".csv":
             if self.sigma is None:
@@ -383,7 +373,7 @@ def _get_sample(self, index):
                 shape = shape[:-1]
             else:
                 shape = shape[1:]
-                prefix_box = (slice(None), )
+                prefix_box = (slice(None),)
 
         raw_patch = np.array(raw[prefix_box + bb])
 
@@ -485,18 +475,33 @@ def get_stacc_data_loader(
         The data loader for the test split.
     """
 
-    train_images, train_labels, val_images, val_labels, test_images, test_labels = \
+    train_images, train_labels, val_images, val_labels, test_images, test_labels = (
         _split_data_paths_into_training_dataset(train_dataset_file)
+    )
 
     if raw_transform is None:
         raw_transform = standardize
 
-    train_set = StaccImageCollectionDataset(train_images, train_labels, patch_shape, eps=eps, sigma=sigma,
-                                            lower_bound=lower_bound, upper_bound=upper_bound,
-                                            raw_transform=raw_transform)
-    val_set = StaccImageCollectionDataset(val_images, val_labels, patch_shape, eps=eps, sigma=sigma,
-                                          lower_bound=lower_bound, upper_bound=upper_bound,
-                                          raw_transform=raw_transform)
+    train_set = StaccImageCollectionDataset(
+        train_images,
+        train_labels,
+        patch_shape,
+        eps=eps,
+        sigma=sigma,
+        lower_bound=lower_bound,
+        upper_bound=upper_bound,
+        raw_transform=raw_transform,
+    )
+    val_set = StaccImageCollectionDataset(
+        val_images,
+        val_labels,
+        patch_shape,
+        eps=eps,
+        sigma=sigma,
+        lower_bound=lower_bound,
+        upper_bound=upper_bound,
+        raw_transform=raw_transform,
+    )
 
     train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
     val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
@@ -507,9 +512,16 @@ def get_stacc_data_loader(
     if test_images is None:
         test_dataloader = None
     else:
-        test_set = StaccImageCollectionDataset(test_images, test_labels, patch_shape, eps=eps, sigma=sigma,
-                                               lower_bound=lower_bound, upper_bound=upper_bound,
-                                               raw_transform=raw_transform)
+        test_set = StaccImageCollectionDataset(
+            test_images,
+            test_labels,
+            patch_shape,
+            eps=eps,
+            sigma=sigma,
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+            raw_transform=raw_transform,
+        )
         test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
         test_dataloader.shuffle = True
 
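A hypothetical call with a placeholder dataset file, assuming `train_dataset_file` is the first parameter and that the function returns the train, validation, and test loaders built above (the test loader may be None when no test split is listed):

# Placeholder file name; keyword names are taken from the parameters used in the function body.
train_loader, val_loader, test_loader = get_stacc_data_loader(
    train_dataset_file="data/colony_split.json",
    patch_shape=(256, 256),
    batch_size=4,
    n_workers=2,
)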