Skip to content

Commit 48e56d3

Browse files
authored
Add optional transforms argument to LoadStreams() (#9105)
* Add optional `transforms` argument to LoadStreams() — prepare for streaming classification support. Signed-off-by: Glenn Jocher <[email protected]>
* Cleanup. Signed-off-by: Glenn Jocher <[email protected]>
* fix
* batch size > 1 fix. Signed-off-by: Glenn Jocher <[email protected]>
1 parent d0fa004 commit 48e56d3

File tree

1 file changed

+25
-29
lines changed

1 file changed

+25
-29
lines changed

utils/dataloaders.py

Lines changed: 25 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -251,7 +251,7 @@ def __next__(self):
251251
s = f'image {self.count}/{self.nf} {path}: '
252252

253253
if self.transforms:
254-
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # classify transforms
254+
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # transforms
255255
else:
256256
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
257257
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
@@ -289,30 +289,28 @@ def __next__(self):
289289
raise StopIteration
290290

291291
# Read frame
292-
ret_val, img0 = self.cap.read()
293-
img0 = cv2.flip(img0, 1) # flip left-right
292+
ret_val, im0 = self.cap.read()
293+
im0 = cv2.flip(im0, 1) # flip left-right
294294

295295
# Print
296296
assert ret_val, f'Camera Error {self.pipe}'
297297
img_path = 'webcam.jpg'
298298
s = f'webcam {self.count}: '
299299

300-
# Padded resize
301-
img = letterbox(img0, self.img_size, stride=self.stride)[0]
300+
# Process
301+
im = letterbox(im0, self.img_size, stride=self.stride)[0] # resize
302+
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
303+
im = np.ascontiguousarray(im) # contiguous
302304

303-
# Convert
304-
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
305-
img = np.ascontiguousarray(img)
306-
307-
return img_path, img, img0, None, s
305+
return img_path, im, im0, None, s
308306

309307
def __len__(self):
310308
return 0
311309

312310

313311
class LoadStreams:
314312
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
315-
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
313+
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
316314
self.mode = 'stream'
317315
self.img_size = img_size
318316
self.stride = stride
@@ -326,7 +324,6 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
326324
n = len(sources)
327325
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
328326
self.sources = [clean_str(x) for x in sources] # clean source names for later
329-
self.auto = auto
330327
for i, s in enumerate(sources): # index, source
331328
# Start thread to read frames from video stream
332329
st = f'{i + 1}/{n}: {s}... '
@@ -353,8 +350,10 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
353350
LOGGER.info('') # newline
354351

355352
# check for common shapes
356-
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
353+
s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
357354
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
355+
self.auto = auto and self.rect
356+
self.transforms = transforms # optional
358357
if not self.rect:
359358
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
360359

@@ -385,18 +384,15 @@ def __next__(self):
385384
cv2.destroyAllWindows()
386385
raise StopIteration
387386

388-
# Letterbox
389-
img0 = self.imgs.copy()
390-
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
391-
392-
# Stack
393-
img = np.stack(img, 0)
394-
395-
# Convert
396-
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
397-
img = np.ascontiguousarray(img)
387+
im0 = self.imgs.copy()
388+
if self.transforms:
389+
im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0]) # transforms
390+
else:
391+
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
392+
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
393+
im = np.ascontiguousarray(im) # contiguous
398394

399-
return self.sources, img, img0, None, ''
395+
return self.sources, im, im0, None, ''
400396

401397
def __len__(self):
402398
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
@@ -836,7 +832,7 @@ def collate_fn(batch):
836832

837833
@staticmethod
838834
def collate_fn4(batch):
839-
img, label, path, shapes = zip(*batch) # transposed
835+
im, label, path, shapes = zip(*batch) # transposed
840836
n = len(shapes) // 4
841837
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
842838

@@ -846,13 +842,13 @@ def collate_fn4(batch):
846842
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
847843
i *= 4
848844
if random.random() < 0.5:
849-
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
850-
align_corners=False)[0].type(img[i].type())
845+
im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
846+
align_corners=False)[0].type(im[i].type())
851847
lb = label[i]
852848
else:
853-
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
849+
im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
854850
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
855-
im4.append(im)
851+
im4.append(im1)
856852
label4.append(lb)
857853

858854
for i, lb in enumerate(label4):

0 commit comments

Comments (0)