Skip to content

Commit 0d8a184

Browse files
Add crops = results.crops() dictionary (#4676)
* adding get cropped functionality * Add target logic in existing functions * Crops cleanup * Add dictionary keys: conf, cls, box * Bug fixes - avoid return after first image Co-authored-by: Glenn Jocher <[email protected]>
1 parent 8e94bf6 commit 0d8a184

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

models/common.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
365365
self.s = shape # inference BCHW shape
366366

367367
def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
368+
crops = []
368369
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
369370
str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
370371
if pred.shape[0]:
@@ -376,7 +377,9 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False
376377
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
377378
label = f'{self.names[int(cls)]} {conf:.2f}'
378379
if crop:
379-
save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
380+
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
381+
crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
382+
'im': save_one_box(box, im, file=file, save=save)})
380383
else: # all others
381384
annotator.box_label(box, label, color=colors(cls))
382385
im = annotator.im
@@ -395,6 +398,10 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False
395398
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
396399
if render:
397400
self.imgs[i] = np.asarray(im)
401+
if crop:
402+
if save:
403+
LOGGER.info(f'Saved results to {save_dir}\n')
404+
return crops
398405

399406
def print(self):
400407
self.display(pprint=True) # print results
@@ -408,10 +415,9 @@ def save(self, save_dir='runs/detect/exp'):
408415
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
409416
self.display(save=True, save_dir=save_dir) # save results
410417

411-
def crop(self, save_dir='runs/detect/exp'):
412-
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
413-
self.display(crop=True, save_dir=save_dir) # crop results
414-
LOGGER.info(f'Saved results to {save_dir}\n')
418+
def crop(self, save=True, save_dir='runs/detect/exp'):
419+
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
420+
return self.display(crop=True, save=save, save_dir=save_dir) # crop results
415421

416422
def render(self):
417423
self.display(render=True) # render results

0 commit comments

Comments
 (0)