Skip to content

Commit 338b111

Browse files
No public description
PiperOrigin-RevId: 794717467
1 parent 69505b3 commit 338b111

File tree

2 files changed

+281
-0
lines changed

2 files changed

+281
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for milk pouch detection."""
16+
17+
from collections.abc import Mapping
18+
import dataclasses
19+
from typing import Any
20+
21+
import numpy as np
22+
import torch
23+
import torchvision
24+
25+
26+
@dataclasses.dataclass(frozen=True)
class _BoundingBox:
  """An immutable axis-aligned box described by two corner points.

  Attributes:
    x1: X coordinate of the first corner.
    y1: Y coordinate of the first corner.
    x2: X coordinate of the opposite corner.
    y2: Y coordinate of the opposite corner.
  """

  x1: float
  y1: float
  x2: float
  y2: float
33+
34+
35+
def _box_area(box: _BoundingBox) -> float:
  """Computes the area covered by a bounding box.

  Side lengths are clamped at zero, so a box whose second corner does not lie
  beyond its first corner on either axis has an area of zero.

  Args:
    box: The bounding box to measure.

  Returns:
    The width multiplied by the height of the box.
  """
  width = max(0, box.x2 - box.x1)
  height = max(0, box.y2 - box.y1)
  return width * height
45+
46+
47+
def _calculate_iou(
    box1: _BoundingBox,
    box2: _BoundingBox
) -> float:
  """Computes the Intersection over Union (IoU) of two bounding boxes.

  Args:
    box1: The first bounding box.
    box2: The second bounding box.

  Returns:
    The intersection area divided by the union area, or 0.0 when the union
    area is zero.
  """
  # Extent of the overlapping rectangle along each axis, clamped at zero when
  # the boxes do not overlap on that axis.
  overlap_width = max(0, min(box1.x2, box2.x2) - max(box1.x1, box2.x1))
  overlap_height = max(0, min(box1.y2, box2.y2) - max(box1.y1, box2.y1))
  intersection = overlap_width * overlap_height

  # The overlap is counted in both areas, so subtract it once.
  union = _box_area(box1) + _box_area(box2) - intersection

  if union == 0:
    return 0.0
  return intersection / union
78+
79+
80+
def _is_contained(
    inner_box: _BoundingBox,
    outer_box: _BoundingBox,
    margin: int = 5,
) -> bool:
  """Tells whether a box lies inside another box, allowing some slack.

  Args:
    inner_box: The box expected to be on the inside.
    outer_box: The box expected to surround `inner_box`.
    margin: How far each edge of `inner_box` may extend past the matching edge
      of `outer_box` and still count as contained.

  Returns:
    True if every edge of `inner_box` is within `margin` of being inside
    `outer_box`, False otherwise.
  """
  within_left = inner_box.x1 >= outer_box.x1 - margin
  within_top = inner_box.y1 >= outer_box.y1 - margin
  within_right = inner_box.x2 <= outer_box.x2 + margin
  within_bottom = inner_box.y2 <= outer_box.y2 + margin
  return within_left and within_top and within_right and within_bottom
102+
103+
104+
def filter_boxes_keep_smaller(
    data: Mapping[str, list[Any]],
    iou_threshold: float = 0.8,
    area_threshold: int | None = None,
    min_area: int = 1000,
    margin: int = 5,
) -> dict[str, list[Any]]:
  """Filters overlapping bounding boxes, preferentially keeping smaller ones.

  Boxes are visited from smallest to largest area (ties keep their input
  order). A box is skipped when its area is below `min_area` or above
  `area_threshold`. Otherwise it is discarded when, for any already-kept box,
  either its IoU with that box exceeds `iou_threshold` or it encloses that box
  (within `margin`). Because smaller boxes are visited first, a large box that
  wraps an already-kept smaller detection is the one that gets dropped.

  Args:
    data: A mapping with a 'boxes' entry holding (x1, y1, x2, y2) sequences
      and a 'masks' entry holding one mask per box, in the same order.
    iou_threshold: The IoU value above which a box is considered an overlap.
    area_threshold: An optional maximum area; larger boxes are skipped.
    min_area: The minimum area required for a box to be kept.
    margin: The pixel margin used for the containment check.

  Returns:
    A dictionary with the kept 'boxes' (as [x1, y1, x2, y2] lists, ordered by
    increasing area) and their corresponding 'masks'. Masks are the same
    objects that were passed in; they are not copied or converted.

  Raises:
    ValueError: If the number of boxes and masks differ.
  """
  boxes = data['boxes']
  masks = data['masks']
  if len(boxes) != len(masks):
    raise ValueError(
        'Expected one mask per box, got '
        f'{len(boxes)} boxes and {len(masks)} masks.'
    )

  bounding_boxes = [_BoundingBox(*b) for b in boxes]
  areas = [_box_area(b) for b in bounding_boxes]

  # Visit boxes from smallest to largest area. A stable sort makes the outcome
  # for equal-area boxes depend only on their input order.
  visit_order = np.argsort(areas, kind='stable').tolist()

  kept_boxes = []
  kept_masks = []
  kept_bounding_boxes = []

  for index in visit_order:
    box = bounding_boxes[index]
    area = areas[index]
    if (
        area_threshold is not None and area > area_threshold
    ) or area < min_area:
      continue

    # Note the argument order of the containment check: it asks whether the
    # already-kept (smaller) box sits inside the current (larger) one.
    is_redundant = any(
        _calculate_iou(box, kept_box) > iou_threshold
        or _is_contained(kept_box, box, margin)
        for kept_box in kept_bounding_boxes
    )
    if is_redundant:
      continue

    kept_boxes.append([box.x1, box.y1, box.x2, box.y2])
    # Index the caller's masks directly instead of stacking them into one
    # array: this avoids copying every mask and works for differing shapes.
    kept_masks.append(masks[index])
    kept_bounding_boxes.append(box)

  return {'boxes': kept_boxes, 'masks': kept_masks}
164+
165+
166+
def convert_boxes_cxcywh_to_xyxy(
    boxes: torch.Tensor, image_shape: tuple[int, int, int]
) -> np.ndarray:
  """Converts bounding boxes from center-based to corner-based format.

  Each box is first multiplied by (w, h, w, h) taken from `image_shape`, so
  the input is presumably expressed as fractions of the image size. The scaled
  boxes are then converted to corner format and cast to integers.

  Args:
    boxes: A tensor of bounding boxes in (cx, cy, w, h) format, one box per
      row.
    image_shape: The image dimensions as (h, w, c). Only the first two entries
      are used, so a plain (h, w) shape is accepted as well.

  Returns:
    An integer NumPy array of bounding boxes in (x1, y1, x2, y2) format.
    Coordinates are truncated toward zero by the integer cast, not rounded.
  """
  h, w = image_shape[:2]
  scale_factors = torch.tensor([w, h, w, h], device=boxes.device)
  scaled_boxes = boxes * scale_factors
  xyxy_boxes = torchvision.ops.box_convert(
      boxes=scaled_boxes, in_fmt='cxcywh', out_fmt='xyxy'
  )
  # detach() lets tensors that still track gradients (e.g. model outputs
  # produced outside torch.no_grad()) be exported to NumPy without an error.
  return xyxy_boxes.detach().cpu().numpy().astype(int)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
import torch
18+
from official.projects.waste_identification_ml.llm_applications.milk_pouch_detection import models_utils
19+
20+
21+
class UtilsTest(unittest.TestCase):
  """Tests for the utility functions."""

  def test_filter_boxes_keep_smaller_on_given_data(self):
    """Tests that enclosing boxes are dropped and masks stay paired."""
    boxes = [
        [402.24, 0.54, 1343.04, 350.46],
        [402.24, 0.54, 955.20, 333.18],
        [930.24, 0.54, 1343.04, 351.54],
        [611.52, 0.54, 955.20, 334.26],
        [402.24, 0.54, 751.68, 305.10],
        [749.76, 592.38, 1055.04, 875.34],
        [941.76, 1012.50, 1039.68, 1078.38],
    ]
    masks = [f"mask_{i}" for i in range(len(boxes))]
    data = {"boxes": boxes, "masks": masks}

    # The first two input boxes enclose the smaller box
    # [402.24, 0.54, 751.68, 305.10], so they are filtered out.
    expected_boxes = [
        [941.76, 1012.50, 1039.68, 1078.38],
        [749.76, 592.38, 1055.04, 875.34],
        [402.24, 0.54, 751.68, 305.10],
        [611.52, 0.54, 955.20, 334.26],
        [930.24, 0.54, 1343.04, 351.54],
    ]
    # Masks that belong to `expected_boxes`, in the same order.
    expected_masks = ["mask_6", "mask_5", "mask_4", "mask_3", "mask_2"]

    result = models_utils.filter_boxes_keep_smaller(data)

    self.assertEqual(len(result["boxes"]), len(expected_boxes))
    self.assertEqual(len(result["masks"]), len(expected_masks))

    # Sort (box, mask) pairs by box so the comparison does not depend on the
    # order in which boxes are returned, while still checking the pairing.
    actual_pairs = sorted(zip(result["boxes"], result["masks"]))
    expected_pairs = sorted(zip(expected_boxes, expected_masks))

    for (actual_box, actual_mask), (expected_box, expected_mask) in zip(
        actual_pairs, expected_pairs
    ):
      np.testing.assert_almost_equal(actual_box, expected_box, decimal=2)
      self.assertEqual(actual_mask, expected_mask)

  def test_convert_boxes_cxcywh_to_xyxy_withsinglebox_returnscorrectcoordinates(
      self,
  ):
    """Tests that a single box is converted correctly."""
    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # cx, cy, w, h
    image_shape = (100, 200, 3)  # h, w, c
    expected_boxes = np.array([[80, 30, 120, 70]])  # x1, y1, x2, y2

    converted_boxes = models_utils.convert_boxes_cxcywh_to_xyxy(
        boxes, image_shape
    )
    np.testing.assert_array_equal(converted_boxes, expected_boxes)

  def test_convert_boxes_cxcywh_to_xyxy_withmultipleboxes_returnscorrectcoordinates(
      self,
  ):
    """Tests that multiple boxes are converted correctly."""
    boxes = torch.tensor([
        [0.5, 0.5, 0.2, 0.4],
        [0.25, 0.25, 0.1, 0.1],
    ])
    image_shape = (100, 200, 3)
    expected_boxes = np.array([
        [80, 30, 120, 70],
        [40, 20, 60, 30],
    ])
    converted_boxes = models_utils.convert_boxes_cxcywh_to_xyxy(
        boxes, image_shape
    )
    np.testing.assert_array_equal(converted_boxes, expected_boxes)

  def test_convert_boxes_cxcywh_to_xyxy_withemptyinput_returnsemptyarray(
      self,
  ):
    """Tests that an empty batch of boxes yields an empty (0, 4) array."""
    boxes = torch.empty((0, 4))
    image_shape = (100, 200, 3)
    expected_boxes = np.empty((0, 4), dtype=int)
    converted_boxes = models_utils.convert_boxes_cxcywh_to_xyxy(
        boxes, image_shape
    )
    np.testing.assert_array_equal(converted_boxes, expected_boxes)
94+
95+
96+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
97+
unittest.main()

0 commit comments

Comments
 (0)