
Commit f47985c

No public description
PiperOrigin-RevId: 795199950
1 parent 338b111 commit f47985c

File tree
  • official/projects/waste_identification_ml/llm_applications/milk_pouch_detection

1 file changed: +235 −0 lines
@@ -0,0 +1,235 @@
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision and LLM models for milk pouch detection."""

import math
import subprocess
from typing import Any, Optional
import warnings

from groundingdino.util import inference
import numpy as np
import ollama
from sam2 import build_sam
from sam2 import sam2_image_predictor
import torch

from official.projects.waste_identification_ml.llm_applications.milk_pouch_detection import models_utils


# Suppress common warnings for a cleaner console output.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)


class VisionModels:
  """Encapsulates vision models for object detection and segmentation.

  This class provides a high-level API for using Grounding DINO and SAM2.
  Models are loaded into memory once during initialization to avoid redundant
  loading and improve performance for sequential processing tasks.

  Attributes:
    dino_model: The loaded Grounding DINO model.
    sam_predictor: The initialized SAM2 predictor instance.
    device: The PyTorch device (e.g., 'cuda' or 'cpu') the models run on.
  """

  def __init__(
      self,
      dino_config_path: str,
      dino_weights_path: str,
      sam_config_path: str,
      sam_checkpoint_path: str,
      device: str = 'cuda',
  ) -> None:
    """Initializes the vision pipeline by loading and setting up models.

    Args:
      dino_config_path: Path to the Grounding DINO configuration file.
      dino_weights_path: Path to the Grounding DINO model weights file.
      sam_config_path: Path to the SAM2 model configuration file.
      sam_checkpoint_path: Path to the SAM2 model checkpoint file.
      device: The hardware device to run models on (e.g., "cuda", "cpu").
    """
    self.device = torch.device(device)

    print('Loading Grounding DINO model...')
    self.dino_model = inference.load_model(dino_config_path, dino_weights_path)
    self.dino_model.to(self.device)
    print('✅ Grounding DINO model loaded.')

    print('Loading SAM2 model...')
    sam2_model = build_sam.build_sam2(
        sam_config_path, sam_checkpoint_path, device=self.device
    )
    self.sam_predictor = sam2_image_predictor.SAM2ImagePredictor(sam2_model)
    print('✅ SAM2 predictor initialized.')

  def detect_objects(
      self,
      image_path: str,
      text_prompt: str,
      box_threshold: float = 0.25,
      text_threshold: float = 0.25,
  ) -> tuple[np.ndarray, np.ndarray, torch.Tensor, list[str]]:
    """Detects objects in an image using Grounding DINO based on a prompt.

    Args:
      image_path: The file path to the input image.
      text_prompt: The text description of objects to detect.
      box_threshold: The confidence threshold for object bounding boxes.
      text_threshold: The confidence threshold for text-based labels.

    Returns:
      A tuple containing:
        - image: The original image loaded as a NumPy array.
        - xyxy_boxes: Detected bounding boxes in [x1, y1, x2, y2] format.
        - scores: Confidence scores for each detected box.
        - labels: Text labels corresponding to each box.
    """
    image, transformed_image = inference.load_image(image_path)
    transformed_image = transformed_image.to(self.device)

    boxes, scores, labels = inference.predict(
        model=self.dino_model,
        image=transformed_image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    xyxy_boxes = models_utils.convert_boxes_cxcywh_to_xyxy(boxes, image.shape)
    return image, xyxy_boxes, scores, labels

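Aside (not part of this file): the box conversion above relies on models_utils, which is not included in this commit. As a point of reference only, a minimal sketch of what convert_boxes_cxcywh_to_xyxy could look like is shown here, assuming Grounding DINO's convention of normalized (cx, cy, w, h) boxes; the actual signature in models_utils may differ.

import numpy as np
import torch


def convert_boxes_cxcywh_to_xyxy(
    boxes: torch.Tensor, image_shape: tuple[int, ...]
) -> np.ndarray:
  """Scales normalized (cx, cy, w, h) boxes to pixel [x1, y1, x2, y2]."""
  h, w = image_shape[:2]
  boxes = boxes.cpu().numpy()
  cx, cy = boxes[:, 0] * w, boxes[:, 1] * h
  bw, bh = boxes[:, 2] * w, boxes[:, 3] * h
  x1, y1 = cx - bw / 2, cy - bh / 2
  x2, y2 = cx + bw / 2, cy + bh / 2
  return np.stack([x1, y1, x2, y2], axis=1)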
  def segment_objects(
      self, image_source: np.ndarray, xyxy_boxes: np.ndarray
  ) -> tuple[list[np.ndarray], list[torch.Tensor], list[np.ndarray]]:
    """Generates segmentation masks for given bounding boxes using SAM2.

    Args:
      image_source: The source image as a NumPy array.
      xyxy_boxes: A NumPy array of bounding boxes in [x1, y1, x2, y2] format.

    Returns:
      A tuple containing:
        - all_masks: A list of boolean segmentation masks.
        - all_scores: A list of confidence scores for each mask.
        - all_boxes: A list of the original bounding boxes.
    """
    self.sam_predictor.set_image(image_source)

    all_masks, all_scores, all_boxes = [], [], []
    for bbox in xyxy_boxes:
      box_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
      # Only segment boxes smaller than 25% of the total image area.
      if box_area < 0.25 * math.prod(image_source.shape[:2]):
        masks, scores, _ = self.sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=bbox[None, :],  # SAM expects a batch dimension.
            multimask_output=False,
        )
        # Squeeze to remove batch and multi-mask dimensions.
        all_masks.append(masks.squeeze())
        all_scores.append(scores)
        all_boxes.append(bbox)

    return all_masks, all_scores, all_boxes

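Aside (not part of this file): the boxes returned above can feed a downstream classification stage. A hypothetical helper is sketched here that crops each detection to its own image file; the helper name, the Pillow dependency, and the output paths are assumptions, not values from this commit.

import os

import numpy as np
from PIL import Image


def crop_detections(
    image: np.ndarray, boxes: list[np.ndarray], out_dir: str
) -> list[str]:
  """Saves one cropped image per box and returns the file paths."""
  os.makedirs(out_dir, exist_ok=True)
  paths = []
  for i, box in enumerate(boxes):
    x1, y1, x2, y2 = (int(v) for v in box)
    crop = image[y1:y2, x1:x2]
    path = os.path.join(out_dir, f'crop_{i}.png')
    Image.fromarray(crop).save(path)
    paths.append(path)
  return paths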
  def process_image(
      self, image_path: str, text_prompt: str
  ) -> Optional[dict[str, Any]]:
    """Runs the full detection and segmentation pipeline on an image.

    Args:
      image_path: The file path to the input image.
      text_prompt: The text description of objects to detect and segment.

    Returns:
      A dictionary containing the processed data ('image', 'boxes', 'masks')
      or None if no objects were detected.
    """
    print(f"\nProcessing '{image_path}'")
    image, boxes, _, _ = self.detect_objects(image_path, text_prompt)

    if boxes.shape[0] == 0:
      print('No objects detected.')
      return None

    masks, _, mask_boxes = self.segment_objects(image, boxes)
    print('Segmentation complete.')

    return {
        'image': image,
        'boxes': mask_boxes,
        'masks': masks,
    }

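Aside (not part of this file): a hypothetical end-to-end call of the vision pipeline. The config/checkpoint file names, the input image, and the prompt string are placeholders rather than values from this commit.

vision = VisionModels(
    dino_config_path='GroundingDINO_SwinT_OGC.py',    # placeholder path
    dino_weights_path='groundingdino_swint_ogc.pth',  # placeholder path
    sam_config_path='sam2_hiera_l.yaml',              # placeholder path
    sam_checkpoint_path='sam2_hiera_large.pt',        # placeholder path
    device='cuda',
)
result = vision.process_image('input.jpg', text_prompt='milk pouch .')
if result is not None:
  print(f"Segmented {len(result['masks'])} objects.")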
class LlmModels:
  """Provides an interface to interact with a local LLM via Ollama."""

  def query_image_with_llm(
      self, image_path: str, prompt: str, model_name: str
  ) -> str:
    """Sends an image and a text prompt to a local Ollama LLM.

    Args:
      image_path: Path to the image file.
      prompt: The question or prompt for the LLM.
      model_name: The name of the Ollama model to use (e.g., 'llava').

    Returns:
      The text response from the LLM.
    """
    response: ollama.ChatResponse = ollama.chat(
        model=model_name,
        messages=[{'role': 'user', 'content': prompt, 'images': [image_path]}],
        options={
            'temperature': 0.0,
        },
    )
    return response['message']['content']

  def stop_model(self, model_name: str) -> None:
    """Stops a running Ollama model to free up system resources.

    This function executes the 'ollama stop' command-line instruction.

    Args:
      model_name: The name of the Ollama model to stop.
    """
    print(f'Attempting to stop Ollama model: {model_name}...')
    try:
      result = subprocess.run(
          ['ollama', 'stop', model_name],
          capture_output=True,
          text=True,
          check=False,
      )
      if result.returncode == 0:
        print(f'✅ Successfully sent stop command for model: {model_name}')
      else:
        # This may not be an error if the model wasn't running.
        print(
            'Info: Could not stop model (may not be running):'
            f' {result.stderr.strip()}'
        )
    except FileNotFoundError:
      print(
          "⚠️ 'ollama' command not found. Is Ollama installed and in your PATH?"
      )
    except subprocess.CalledProcessError as e:
      print(f'⚠️ An unexpected error occurred: {e}')
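
Aside (not part of this file): how the LLM interface might classify one of the crops produced by the vision stage. The model name 'llava', the image path, and the prompt wording are assumptions; any local multimodal Ollama model could be substituted.

llm = LlmModels()
answer = llm.query_image_with_llm(
    image_path='crop_0.png',  # placeholder crop from the vision stage
    prompt='Is this a milk pouch? Answer yes or no.',
    model_name='llava',
)
print(answer)
llm.stop_model('llava')  # Free system resources once classification is done.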
