Skip to content

Commit 797956a

Browse files
Merge pull request #23 from computational-cell-analytics/issue-15
Issue 15
2 parents 8daf4b0 + 277a4a4 commit 797956a

File tree

2 files changed

+55
-43
lines changed

2 files changed

+55
-43
lines changed

stacc/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def main():
165165
model = get_model(args.model, args.custom_model)
166166

167167
# Check that all arguments are given for custom model
168-
if args.custom_model:
168+
if args.custom_model is not None and args.custom_distance is not None and args.custom_threshold is not None:
169169
if args.custom_distance is not None and args.custom_threshold is not None:
170170
min_distance = args.custom_distance
171171
threshold_abs = args.custom_threshold

stacc/training/dataset.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,7 @@ def _get_image_id(image_path: str) -> str:
3939

4040

4141
def apply_stamp_to_stacc_gt_label(
42-
stamp: np.ndarray,
43-
position: Tuple[int, int],
44-
label_matrix: np.ndarray,
45-
image_id: str
42+
stamp: np.ndarray, position: Tuple[int, int], label_matrix: np.ndarray, image_id: str
4643
) -> None:
4744
"""Add a stamp (small matrix) at the coordinates 'position' of the label_matrix.
4845
@@ -123,7 +120,7 @@ def create_gaussian_stamp(width: int, lower_bound: float, upper_bound: float, ep
123120

124121
stamp = np.zeros((width, width))
125122
stamp[width // 2, width // 2] = 1
126-
stamp = gaussian(stamp, sigma=sigma, truncate=10.0, mode='constant')
123+
stamp = gaussian(stamp, sigma=sigma, truncate=10.0, mode="constant")
127124
stamp[np.where(stamp < eps)] = 0 # Truncate to make the stamp circular
128125
stamp = stamp * 2 * 4 * np.pi * sigma**2
129126

@@ -139,16 +136,16 @@ def create_stacc_ground_truth_label_from_json(
139136
upper_bound: Optional[float] = None,
140137
bounding_box: Optional[Tuple[slice]] = None,
141138
):
142-
"""Create a ground truth label matrix from JSON bounding box annotations.
139+
"""Create a ground truth label matrix from JSON point annotations.
143140
144141
Args:
145142
image_path: Path to the input image.
146-
label_path: Path to the corresponding JSON file containing bounding box labels.
143+
label_path: Path to the corresponding JSON file containing point labels.
147144
eps: Epsilon value for Gaussian truncation. Default is 0.00001.
148145
sigma: Sigma value for Gaussian blur. If None, individual stamps are applied.
149146
lower_bound: Lower bound for Gaussian sigma. If None, no lower bound for sigma will be set.
150147
upper_bound: Upper bound for Gaussian sigma If None, no upper bound for sigma will be set.
151-
bounding_box: A bounding box for creating he labels only in a sub-part of the image.
148+
bounding_box: A bounding box for creating the labels only in a sub-part of the image.
152149
153150
Returns:
154151
A 2D array representing the stacc ground truth label matrix.
@@ -158,9 +155,9 @@ def create_stacc_ground_truth_label_from_json(
158155

159156
bboxes = label_dict["labels"]
160157
n_colonies = len(bboxes) # Number of annotations / bounding boxes
158+
image_id = _get_image_id(image_path) # Get image id and check if jpg or tif image
161159

162160
if bounding_box is None:
163-
image_id = _get_image_id(image_path) # Get image id and check if jpg or tif image
164161
im = imread(image_path)
165162
n_rows, n_columns = im.shape[:2]
166163
else:
@@ -171,17 +168,9 @@ def create_stacc_ground_truth_label_from_json(
171168
stacc_gt_label = np.zeros((n_rows, n_columns)) # Create empty ground truth label
172169
if n_colonies > 0:
173170
# Only keep the stacc_gt_label that are inside the image dimensions
174-
reasonable_indices = [
175-
i for i in range(n_colonies) if
176-
(bboxes[i]["x"] + max(int(bboxes[i]["width"]/2), 1)) <= n_columns and
177-
(bboxes[i]["y"] + max(int(bboxes[i]["height"]/2), 1)) <= n_rows
178-
]
179-
x_coordinates = np.array(
180-
[(int(bboxes[i]["y"]) + max(int(bboxes[i]["width"]/2), 1)) for i in reasonable_indices], dtype="int"
181-
)
182-
y_coordinates = np.array(
183-
[(int(bboxes[i]["x"]) + max(int(bboxes[i]["height"]/2), 1)) for i in reasonable_indices], dtype="int"
184-
)
171+
reasonable_indices = [i for i in range(n_colonies) if bboxes[i]["x"] < n_columns and bboxes[i]["y"] < n_rows]
172+
x_coordinates = np.array([int(bboxes[i]["y"]) for i in reasonable_indices], dtype="int")
173+
y_coordinates = np.array([int(bboxes[i]["x"]) for i in reasonable_indices], dtype="int")
185174

186175
# If we have a bounding box we need to subtract the lower corner and filter out points that are outside.
187176
if bounding_box is not None:
@@ -201,8 +190,8 @@ def create_stacc_ground_truth_label_from_json(
201190
else:
202191
# Process each coordinate individually
203192
for i, (x_coord, y_coord) in enumerate(zip(x_coordinates, y_coordinates)):
204-
width = max(int(bboxes[reasonable_indices[i]]['width']), 1)
205-
height = max(int(bboxes[reasonable_indices[i]]['height']), 1)
193+
width = max(int(bboxes[reasonable_indices[i]]["width"]), 1)
194+
height = max(int(bboxes[reasonable_indices[i]]["height"]), 1)
206195
width = min(width, height) # Make the stamp square
207196
if width % 2 == 0:
208197
width -= 1
@@ -267,8 +256,8 @@ def create_stacc_labels_from_csv(
267256

268257

269258
class StaccImageCollectionDataset(torch.utils.data.Dataset):
270-
"""@private
271-
"""
259+
"""@private"""
260+
272261
max_sampling_attempts = 500
273262

274263
def _check_inputs(self, raw_images, label_images):
@@ -337,10 +326,7 @@ def _sample_bounding_box(self, shape):
337326
raise NotImplementedError(
338327
f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
339328
)
340-
bb_start = [
341-
np.random.randint(0, sh - psh) if sh - psh > 0 else 0
342-
for sh, psh in zip(shape, self.patch_shape)
343-
]
329+
bb_start = [np.random.randint(0, sh - psh) if sh - psh > 0 else 0 for sh, psh in zip(shape, self.patch_shape)]
344330
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
345331

346332
def _get_sample(self, index):
@@ -358,9 +344,13 @@ def _get_sample(self, index):
358344
_, label_extension = os.path.splitext(label_path)
359345
if label_extension == ".json":
360346
label_patch = create_stacc_ground_truth_label_from_json(
361-
raw_path, label_path, eps=self.eps, sigma=self.sigma,
362-
lower_bound=self.lower_bound, upper_bound=self.upper_bound,
363-
bounding_box=bb
347+
raw_path,
348+
label_path,
349+
eps=self.eps,
350+
sigma=self.sigma,
351+
lower_bound=self.lower_bound,
352+
upper_bound=self.upper_bound,
353+
bounding_box=bb,
364354
)
365355
elif label_extension.lower() == ".csv":
366356
if self.sigma is None:
@@ -383,7 +373,7 @@ def _get_sample(self, index):
383373
shape = shape[:-1]
384374
else:
385375
shape = shape[1:]
386-
prefix_box = (slice(None), )
376+
prefix_box = (slice(None),)
387377

388378
raw_patch = np.array(raw[prefix_box + bb])
389379

@@ -485,18 +475,33 @@ def get_stacc_data_loader(
485475
The data loader for the test split.
486476
"""
487477

488-
train_images, train_labels, val_images, val_labels, test_images, test_labels =\
478+
train_images, train_labels, val_images, val_labels, test_images, test_labels = (
489479
_split_data_paths_into_training_dataset(train_dataset_file)
480+
)
490481

491482
if raw_transform is None:
492483
raw_transform = standardize
493484

494-
train_set = StaccImageCollectionDataset(train_images, train_labels, patch_shape, eps=eps, sigma=sigma,
495-
lower_bound=lower_bound, upper_bound=upper_bound,
496-
raw_transform=raw_transform)
497-
val_set = StaccImageCollectionDataset(val_images, val_labels, patch_shape, eps=eps, sigma=sigma,
498-
lower_bound=lower_bound, upper_bound=upper_bound,
499-
raw_transform=raw_transform)
485+
train_set = StaccImageCollectionDataset(
486+
train_images,
487+
train_labels,
488+
patch_shape,
489+
eps=eps,
490+
sigma=sigma,
491+
lower_bound=lower_bound,
492+
upper_bound=upper_bound,
493+
raw_transform=raw_transform,
494+
)
495+
val_set = StaccImageCollectionDataset(
496+
val_images,
497+
val_labels,
498+
patch_shape,
499+
eps=eps,
500+
sigma=sigma,
501+
lower_bound=lower_bound,
502+
upper_bound=upper_bound,
503+
raw_transform=raw_transform,
504+
)
500505

501506
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
502507
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
@@ -507,9 +512,16 @@ def get_stacc_data_loader(
507512
if test_images is None:
508513
test_dataloader = None
509514
else:
510-
test_set = StaccImageCollectionDataset(test_images, test_labels, patch_shape, eps=eps, sigma=sigma,
511-
lower_bound=lower_bound, upper_bound=upper_bound,
512-
raw_transform=raw_transform)
515+
test_set = StaccImageCollectionDataset(
516+
test_images,
517+
test_labels,
518+
patch_shape,
519+
eps=eps,
520+
sigma=sigma,
521+
lower_bound=lower_bound,
522+
upper_bound=upper_bound,
523+
raw_transform=raw_transform,
524+
)
513525
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=n_workers)
514526
test_dataloader.shuffle = True
515527

0 commit comments

Comments
 (0)