# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision and LLM models for milk pouch detection."""

import math
import subprocess
from typing import Any, Optional
import warnings

from groundingdino.util import inference
import numpy as np
import ollama
from sam2 import build_sam
from sam2 import sam2_image_predictor
import torch

from official.projects.waste_identification_ml.llm_applications.milk_pouch_detection import models_utils


# Suppress common warnings for a cleaner console output.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)


class VisionModels:
  """Encapsulates vision models for object detection and segmentation.

  This class provides a high-level API for using Grounding DINO and SAM2.
  Models are loaded into memory once during initialization to avoid redundant
  loading and improve performance for sequential processing tasks.

  Attributes:
    dino_model: The loaded Grounding DINO model.
    sam_predictor: The initialized SAM2 predictor instance.
    device: The PyTorch device (e.g., 'cuda' or 'cpu') the models run on.
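
  Example:
    A minimal, illustrative setup; the config, weight, checkpoint, and image
    paths below are placeholders for your local files.

      models = VisionModels(
          dino_config_path='/path/to/dino_config.py',
          dino_weights_path='/path/to/dino_weights.pth',
          sam_config_path='/path/to/sam2_config.yaml',
          sam_checkpoint_path='/path/to/sam2_checkpoint.pt',
          device='cuda',
      )
      result = models.process_image('image.jpg', 'milk pouch.')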
  """

  def __init__(
      self,
      dino_config_path: str,
      dino_weights_path: str,
      sam_config_path: str,
      sam_checkpoint_path: str,
      device: str = 'cuda',
  ) -> None:
    """Initializes the vision pipeline by loading and setting up models.

    Args:
      dino_config_path: Path to the Grounding DINO configuration file.
      dino_weights_path: Path to the Grounding DINO model weights file.
      sam_config_path: Path to the SAM2 model configuration file.
      sam_checkpoint_path: Path to the SAM2 model checkpoint file.
      device: The hardware device to run models on (e.g., "cuda", "cpu").
    """
    self.device = torch.device(device)

    print('Loading Grounding DINO model...')
    self.dino_model = inference.load_model(dino_config_path, dino_weights_path)
    self.dino_model.to(self.device)
    print('✅ Grounding DINO model loaded.')

    print('Loading SAM2 model...')
    sam2_model = build_sam.build_sam2(
        sam_config_path, sam_checkpoint_path, device=self.device
    )
    self.sam_predictor = sam2_image_predictor.SAM2ImagePredictor(sam2_model)
    print('✅ SAM2 predictor initialized.')

  def detect_objects(
      self,
      image_path: str,
      text_prompt: str,
      box_threshold: float = 0.25,
      text_threshold: float = 0.25,
  ) -> tuple[np.ndarray, np.ndarray, torch.Tensor, list[str]]:
    """Detects objects in an image using Grounding DINO based on a prompt.

    Args:
      image_path: The file path to the input image.
      text_prompt: The text description of objects to detect.
      box_threshold: The confidence threshold for object bounding boxes.
      text_threshold: The confidence threshold for text-based labels.

    Returns:
      A tuple containing:
        - image: The original image loaded as a NumPy array.
        - xyxy_boxes: Detected bounding boxes in [x1, y1, x2, y2] format.
        - scores: Confidence scores for each detected box.
        - labels: Text labels corresponding to each box.
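
    Example:
      Illustrative usage, assuming `models` is an initialized VisionModels
      instance (see the class docstring); the image path and prompt are
      placeholders.

        image, boxes, scores, labels = models.detect_objects(
            'image.jpg', 'milk pouch.'
        )
        print(f'Detected {boxes.shape[0]} candidate objects.')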
    """
    image, transformed_image = inference.load_image(image_path)
    transformed_image = transformed_image.to(self.device)

    boxes, scores, labels = inference.predict(
        model=self.dino_model,
        image=transformed_image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )

    xyxy_boxes = models_utils.convert_boxes_cxcywh_to_xyxy(boxes, image.shape)
    return image, xyxy_boxes, scores, labels

  def segment_objects(
      self, image_source: np.ndarray, xyxy_boxes: np.ndarray
  ) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
    """Generates segmentation masks for given bounding boxes using SAM2.

    Boxes covering at least 25% of the image area are skipped.

    Args:
      image_source: The source image as a NumPy array.
      xyxy_boxes: A NumPy array of bounding boxes in [x1, y1, x2, y2] format.

    Returns:
      A tuple containing:
        - all_masks: A list of boolean segmentation masks.
        - all_scores: A list of confidence scores for each mask.
        - all_boxes: A list of the bounding boxes that were segmented.
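
    Example:
      Illustrative usage, continuing from the `detect_objects` example
      (`image` and `boxes` are its outputs):

        masks, mask_scores, kept_boxes = models.segment_objects(image, boxes)
        print(f'Generated {len(masks)} masks.')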
    """
    self.sam_predictor.set_image(image_source)

    all_masks, all_scores, all_boxes = [], [], []
    for bbox in xyxy_boxes:
      box_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
      # Only segment boxes smaller than 25% of the image area.
      if box_area < 0.25 * math.prod(image_source.shape[:2]):
        masks, scores, _ = self.sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=bbox[None, :],  # SAM expects a batch dimension.
            multimask_output=False,
        )
        # Squeeze to remove batch and multi-mask dimensions.
        all_masks.append(masks.squeeze())
        all_scores.append(scores)
        all_boxes.append(bbox)

    return all_masks, all_scores, all_boxes

  def process_image(
      self, image_path: str, text_prompt: str
  ) -> Optional[dict[str, Any]]:
    """Runs the full detection and segmentation pipeline on an image.

    Args:
      image_path: The file path to the input image.
      text_prompt: The text description of objects to detect and segment.

    Returns:
      A dictionary containing the processed data ('image', 'boxes', 'masks'),
      or None if no objects were detected.
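
    Example:
      Illustrative usage with a placeholder image path and prompt:

        result = models.process_image('image.jpg', 'milk pouch.')
        if result is not None:
          print(f"Found {len(result['masks'])} segmented objects.")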
    """
    print(f"\nProcessing '{image_path}'")
    image, boxes, _, _ = self.detect_objects(image_path, text_prompt)

    if boxes.shape[0] == 0:
      print('No objects detected.')
      return None

    masks, _, mask_boxes = self.segment_objects(image, boxes)
    print('Segmentation complete.')

    return {
        'image': image,
        'boxes': mask_boxes,
        'masks': masks,
    }


class LlmModels:
  """Provides an interface to interact with a local LLM via Ollama."""

  def query_image_with_llm(
      self, image_path: str, prompt: str, model_name: str
  ) -> str:
    """Sends an image and a text prompt to a local Ollama LLM.

    Args:
      image_path: Path to the image file.
      prompt: The question or prompt for the LLM.
      model_name: The name of the Ollama model to use (e.g., 'llava').

    Returns:
      The text response from the LLM.
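
    Example:
      Illustrative usage; the image path is a placeholder and 'llava' is one
      example of a locally installed Ollama vision model:

        llm = LlmModels()
        answer = llm.query_image_with_llm(
            'crop.png', 'Is this a milk pouch? Answer yes or no.', 'llava'
        )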
    """
    response: ollama.ChatResponse = ollama.chat(
        model=model_name,
        messages=[{'role': 'user', 'content': prompt, 'images': [image_path]}],
        options={
            'temperature': 0.0,
        },
    )
    return response['message']['content']

  def stop_model(self, model_name: str) -> None:
    """Stops a running Ollama model to free up system resources.

    This function executes the 'ollama stop' command-line instruction.

    Args:
      model_name: The name of the Ollama model to stop.
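
    Example:
      Illustrative usage, assuming the 'llava' model was previously queried:

        LlmModels().stop_model('llava')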
| 213 | + """ |
| 214 | + print(f'Attempting to stop Ollama model: {model_name}...') |
| 215 | + try: |
| 216 | + result = subprocess.run( |
| 217 | + ['ollama', 'stop', model_name], |
| 218 | + capture_output=True, |
| 219 | + text=True, |
| 220 | + check=False, |
| 221 | + ) |
| 222 | + if result.returncode == 0: |
| 223 | + print(f'✅ Successfully sent stop command for model: {model_name}') |
| 224 | + else: |
| 225 | + # This may not be an error if the model wasn't running. |
| 226 | + print( |
| 227 | + 'Info: Could not stop model (may not be running):' |
| 228 | + f' {result.stderr.strip()}' |
| 229 | + ) |
| 230 | + except FileNotFoundError: |
| 231 | + print( |
| 232 | + "⚠️ 'ollama' command not found. Is Ollama installed and in your PATH?" |
| 233 | + ) |
| 234 | + except subprocess.CalledProcessError as e: |
| 235 | + print(f'⚠️ An unexpected error occurred: {e}') |