Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# YOLOv5 common modules

import math
from copy import copy
from pathlib import Path

import math
import numpy as np
import pandas as pd
import requests
Expand All @@ -12,7 +12,7 @@
from PIL import Image
from torch.cuda import amp

from utils.datasets import letterbox
from utils.datasets import exif_transpose, letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import time_synchronized
Expand Down Expand Up @@ -252,9 +252,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
for i, im in enumerate(imgs):
f = f'image{i}' # filename
if isinstance(im, str): # filename or uri
im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im
im = np.asarray(exif_transpose(im))
elif isinstance(im, Image.Image): # PIL Image
im, f = np.asarray(im), getattr(im, 'filename', f) or f
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename') or f
files.append(Path(f).with_suffix('.jpg').name)
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
Expand Down
26 changes: 26 additions & 0 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,32 @@ def exif_size(img):
return s


def exif_transpose(image):
"""
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
From https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py

:param image: The image to transpose.
:return: An image.
"""
exif = image.getexif()
orientation = exif.get(0x0112, 1) # default 1
if orientation > 1:
method = {2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
if method is not None:
image = image.transpose(method)
del exif[0x0112]
image.info["exif"] = exif.tobytes()
return image


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
Expand Down