Search before asking
- I have searched the Supervision issues and found no similar feature requests.
Description
Add TIF (GeoTIFF) support to `InferenceSlicer` to enable processing of multi-GB drone and aerial geospatial imagery.
Use case
High-resolution or large-scale drone and aerial geospatial imagery is stored in GeoTIFF files that can exceed the memory of the host machine (50 GB+). Loading such images into memory with PIL is not possible. Instead, a windowed read pattern is generally used to process the file in chunks, which is nearly exactly what `InferenceSlicer` already does.
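For illustration, here is a minimal sketch of the windowed tiling idea in plain Python. `generate_windows` is a hypothetical helper, not Supervision's internal `_generate_offset`; it assumes the overlap is given in pixels. Each tuple follows the argument order of `rasterio.windows.Window(col_off, row_off, width, height)`, so only one tile needs to be held in memory at a time.

```python
def generate_windows(
    width: int, height: int, slice_w: int, slice_h: int, overlap_w: int, overlap_h: int
) -> list[tuple[int, int, int, int]]:
    """Return (col_off, row_off, win_w, win_h) tiles covering a width x height raster.

    Neighbouring tiles overlap by overlap_w / overlap_h pixels; tiles on the
    right and bottom edges are clipped to the raster bounds.
    """
    if not (0 <= overlap_w < slice_w and 0 <= overlap_h < slice_h):
        raise ValueError("overlap must be >= 0 and smaller than the slice size")
    step_x, step_y = slice_w - overlap_w, slice_h - overlap_h
    windows = []
    # Ending the ranges at (size - overlap) means every tile adds new pixels
    # beyond the previous one, while still reaching the raster edge.
    for row_off in range(0, max(height - overlap_h, 1), step_y):
        for col_off in range(0, max(width - overlap_w, 1), step_x):
            win_w = min(slice_w, width - col_off)
            win_h = min(slice_h, height - row_off)
            windows.append((col_off, row_off, win_w, win_h))
    return windows
```

With rasterio, each tuple could then be read as `dataset.read(window=Window(*t))`.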
This would add rasterio as an optional dependency.
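One way to keep `rasterio` optional is to import it lazily, only when a TIF is actually sliced, so that `import supervision` keeps working without it. A minimal sketch; `require_optional` and the `geo` extra name are hypothetical, not existing Supervision APIs.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency, or fail with an install hint.

    `extra` is the name of a hypothetical pip extra (e.g. "geo"); it is only
    used in the error message.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as error:
        raise ImportError(
            f"'{module_name}' is required for this feature. "
            f"Install it with: pip install 'supervision[{extra}]'"
        ) from error
```

Usage would be something like `rasterio = require_optional("rasterio", extra="geo")` inside the code path that handles TIF input.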
The `callback()` function will still receive a NumPy array, but its shape will reflect the number of bands (color channels) in the TIF: rasterio returns band-first arrays of shape `(bands, rows, cols)`. It is up to the callback to select or convert the bands to suit the model.
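As an example of what a callback might do first, here is a sketch that picks three bands from a band-first tile and converts them to an `(rows, cols, 3)` `uint8` image. `bands_to_rgb` is a hypothetical helper, and the fixed `value_range` scaling is an assumption; the right normalization depends on the sensor and the model.

```python
from typing import Sequence

import numpy as np


def bands_to_rgb(
    tile: np.ndarray,
    bands: Sequence[int] = (0, 1, 2),
    value_range: tuple[float, float] = (0.0, 255.0),
) -> np.ndarray:
    """Turn a band-first tile (bands, rows, cols) into a (rows, cols, 3) uint8 image.

    `bands` are 0-based indices into the first axis (rasterio band numbers are
    1-based, so band 1 is index 0). `value_range` is a fixed (low, high) range
    mapped to 0..255; using the same range for every tile keeps brightness
    consistent across slices.
    """
    if tile.ndim != 3:
        raise ValueError(f"expected (bands, rows, cols), got shape {tile.shape}")
    if len(bands) != 3:
        raise ValueError("exactly three bands are needed for a 3-channel image")
    low, high = value_range
    if high <= low:
        raise ValueError("value_range must satisfy low < high")
    selected = tile[list(bands)].astype(np.float64)  # (3, rows, cols)
    scaled = np.clip((selected - low) / (high - low), 0.0, 1.0) * 255.0
    rgb = np.transpose(scaled, (1, 2, 0))  # (rows, cols, 3)
    return np.ascontiguousarray(np.round(rgb).astype(np.uint8))
```

If the model expects BGR channel order, passing the bands reversed (e.g. `(2, 1, 0)`) produces it.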
A check of the TIF's coordinate reference system (CRS) is also required to ensure it is a *projected* CRS (planar XY coordinates suitable for 2D display) rather than a geographic one (e.g. latitude/longitude).
Additional
A reference implementation for my use case is attached below.
```python
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy
import rasterio
import supervision
from rasterio.windows import Window
from tqdm import tqdm


class RasterioWindowedInferenceSlicer(supervision.InferenceSlicer):
    def __call__(
        self, image: rasterio.DatasetReader, progress_desc: str = "InferenceSlicer"
    ) -> supervision.Detections:
        """
        Perform tiled inference on the full image and return merged detections.

        Args:
            image (rasterio.DatasetReader): A GeoTIFF dataset for windowed reads.
            progress_desc (str): Description for the progress bar.

        Returns:
            Detections: Merged detections across all slices.
        """
        # Pixel-space slicing only makes sense for a projected (planar) CRS.
        if image.crs is None or not image.crs.is_projected:
            raise ValueError(
                "A projected coordinate system is required for pixel inference on a "
                f"2D map. Use gdalwarp to reproject. crs={image.crs}"
            )

        # rasterio's `image.shape` is (rows, cols); Supervision expects (width, height).
        resolution_wh = (image.width, image.height)
        offsets = self._generate_offset(
            resolution_wh=resolution_wh,
            slice_wh=self.slice_wh,
            overlap_wh=self.overlap_wh,
        )

        # A single dataset handle must not be read from several threads at once,
        # so reads are serialized; the callbacks (inference) still run in parallel.
        read_lock = threading.Lock()
        detections_list: list[supervision.Detections] = []
        with ThreadPoolExecutor(max_workers=self.thread_workers) as executor:
            futures = [
                executor.submit(self._run_callback, image, offset, read_lock)
                for offset in offsets
            ]
            with tqdm(total=len(futures), desc=progress_desc) as pb:
                for future in as_completed(futures):
                    detections_list.append(future.result())
                    pb.update(1)

        merged = supervision.Detections.merge(detections_list=detections_list)
        if self.overlap_filter == supervision.OverlapFilter.NONE:
            return merged
        if self.overlap_filter == supervision.OverlapFilter.NON_MAX_SUPPRESSION:
            return merged.with_nms(
                threshold=self.iou_threshold,
                overlap_metric=self.overlap_metric,
            )
        if self.overlap_filter == supervision.OverlapFilter.NON_MAX_MERGE:
            return merged.with_nmm(
                threshold=self.iou_threshold,
                overlap_metric=self.overlap_metric,
            )
        warnings.warn(
            f"Invalid overlap filter strategy: {self.overlap_filter}",
            category=supervision.SupervisionWarnings,
        )
        return merged

    def _run_callback(
        self,
        image: rasterio.DatasetReader,
        offset: numpy.ndarray,
        read_lock: threading.Lock,
    ) -> supervision.Detections:
        """
        Run the detection callback on one window of the image and adjust coordinates.

        Args:
            image (rasterio.DatasetReader): The full GeoTIFF dataset.
            offset (numpy.ndarray): Coordinates `(x_min, y_min, x_max, y_max)` defining
                the slice region.
            read_lock (threading.Lock): Serializes reads from the shared dataset.

        Returns:
            Detections: Detections adjusted to the full image coordinate system.
        """
        # Window takes (col_off, row_off, width, height), i.e. xywh.
        x_min, y_min, x_max, y_max = (int(v) for v in offset)
        window = Window(x_min, y_min, x_max - x_min, y_max - y_min)
        with read_lock:
            # Band-first array of shape (bands, rows, cols).
            image_slice = image.read(window=window)
        detections = self.callback(image_slice)
        return supervision.detection.tools.inference_slicer.move_detections(
            detections=detections,
            offset=offset[:2],
            resolution_wh=(image.width, image.height),
        )
```

Are you willing to submit a PR?
- Yes I'd like to help by submitting a PR!