Skip to content

Commit b9142d6

Browse files
Fix save_one_box() (ultralytics#5545)
* Fix `save_one_box()` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 480a429 commit b9142d6

File tree

4 files changed

+76
-76
lines changed

4 files changed

+76
-76
lines changed

detect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from models.experimental import attempt_load
2727
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
2828
from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements,
29-
check_suffix, colorstr, increment_path, non_max_suppression, print_args, save_one_box,
30-
scale_coords, strip_optimizer, xyxy2xywh)
31-
from utils.plots import Annotator, colors
29+
check_suffix, colorstr, increment_path, non_max_suppression, print_args, scale_coords,
30+
strip_optimizer, xyxy2xywh)
31+
from utils.plots import Annotator, colors, save_one_box
3232
from utils.torch_utils import load_classifier, select_device, time_sync
3333

3434

models/common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
from torch.cuda import amp
1919

2020
from utils.datasets import exif_transpose, letterbox
21-
from utils.general import (colorstr, increment_path, make_divisible, non_max_suppression, save_one_box, scale_coords,
22-
xyxy2xywh)
23-
from utils.plots import Annotator, colors
21+
from utils.general import colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xyxy2xywh
22+
from utils.plots import Annotator, colors, save_one_box
2423
from utils.torch_utils import time_sync
2524

2625
LOGGER = logging.getLogger(__name__)

utils/general.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -819,21 +819,6 @@ def apply_classifier(x, model, img, im0):
819819
return x
820820

821821

822-
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
823-
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
824-
xyxy = torch.tensor(xyxy).view(-1, 4)
825-
b = xyxy2xywh(xyxy) # boxes
826-
if square:
827-
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
828-
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
829-
xyxy = xywh2xyxy(b).long()
830-
clip_coords(xyxy, im.shape)
831-
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
832-
if save:
833-
cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
834-
return crop
835-
836-
837822
def increment_path(path, exist_ok=False, sep='', mkdir=False):
838823
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
839824
path = Path(path) # os-agnostic

utils/plots.py

Lines changed: 71 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from PIL import Image, ImageDraw, ImageFont
1919

20-
from utils.general import is_ascii, is_chinese, user_config_dir, xywh2xyxy, xyxy2xywh
20+
from utils.general import clip_coords, increment_path, is_ascii, is_chinese, user_config_dir, xywh2xyxy, xyxy2xywh
2121
from utils.metrics import fitness
2222

2323
# Settings
@@ -117,6 +117,33 @@ def result(self):
117117
return np.asarray(self.im)
118118

119119

120+
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
121+
"""
122+
x: Features to be visualized
123+
module_type: Module type
124+
stage: Module stage within model
125+
n: Maximum number of feature maps to plot
126+
save_dir: Directory to save results
127+
"""
128+
if 'Detect' not in module_type:
129+
batch, channels, height, width = x.shape # batch, channels, height, width
130+
if height > 1 and width > 1:
131+
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
132+
133+
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
134+
n = min(n, channels) # number of plots
135+
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
136+
ax = ax.ravel()
137+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
138+
for i in range(n):
139+
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
140+
ax[i].axis('off')
141+
142+
print(f'Saving {save_dir / f}... ({n}/{channels})')
143+
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
144+
plt.close()
145+
146+
120147
def hist2d(x, y, n=100):
121148
# 2d histogram used in labels.png and evolve.png
122149
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
@@ -337,37 +364,6 @@ def plot_labels(labels, names=(), save_dir=Path('')):
337364
plt.close()
338365

339366

340-
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
341-
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
342-
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
343-
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
344-
files = list(Path(save_dir).glob('frames*.txt'))
345-
for fi, f in enumerate(files):
346-
try:
347-
results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
348-
n = results.shape[1] # number of rows
349-
x = np.arange(start, min(stop, n) if stop else n)
350-
results = results[:, x]
351-
t = (results[0] - results[0].min()) # set t0=0s
352-
results[0] = x
353-
for i, a in enumerate(ax):
354-
if i < len(results):
355-
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
356-
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
357-
a.set_title(s[i])
358-
a.set_xlabel('time (s)')
359-
# if fi == len(files) - 1:
360-
# a.set_ylim(bottom=0)
361-
for side in ['top', 'right']:
362-
a.spines[side].set_visible(False)
363-
else:
364-
a.remove()
365-
except Exception as e:
366-
print(f'Warning: Plotting error for {f}; {e}')
367-
ax[1].legend()
368-
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
369-
370-
371367
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
372368
# Plot evolve.csv hyp evolution results
373369
evolve_csv = Path(evolve_csv)
@@ -420,28 +416,48 @@ def plot_results(file='path/to/results.csv', dir=''):
420416
plt.close()
421417

422418

423-
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
424-
"""
425-
x: Features to be visualized
426-
module_type: Module type
427-
stage: Module stage within model
428-
n: Maximum number of feature maps to plot
429-
save_dir: Directory to save results
430-
"""
431-
if 'Detect' not in module_type:
432-
batch, channels, height, width = x.shape # batch, channels, height, width
433-
if height > 1 and width > 1:
434-
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
419+
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
420+
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
421+
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
422+
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
423+
files = list(Path(save_dir).glob('frames*.txt'))
424+
for fi, f in enumerate(files):
425+
try:
426+
results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
427+
n = results.shape[1] # number of rows
428+
x = np.arange(start, min(stop, n) if stop else n)
429+
results = results[:, x]
430+
t = (results[0] - results[0].min()) # set t0=0s
431+
results[0] = x
432+
for i, a in enumerate(ax):
433+
if i < len(results):
434+
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
435+
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
436+
a.set_title(s[i])
437+
a.set_xlabel('time (s)')
438+
# if fi == len(files) - 1:
439+
# a.set_ylim(bottom=0)
440+
for side in ['top', 'right']:
441+
a.spines[side].set_visible(False)
442+
else:
443+
a.remove()
444+
except Exception as e:
445+
print(f'Warning: Plotting error for {f}; {e}')
446+
ax[1].legend()
447+
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
435448

436-
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
437-
n = min(n, channels) # number of plots
438-
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
439-
ax = ax.ravel()
440-
plt.subplots_adjust(wspace=0.05, hspace=0.05)
441-
for i in range(n):
442-
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
443-
ax[i].axis('off')
444449

445-
print(f'Saving {save_dir / f}... ({n}/{channels})')
446-
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
447-
plt.close()
450+
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
451+
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
452+
xyxy = torch.tensor(xyxy).view(-1, 4)
453+
b = xyxy2xywh(xyxy) # boxes
454+
if square:
455+
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
456+
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
457+
xyxy = xywh2xyxy(b).long()
458+
clip_coords(xyxy, im.shape)
459+
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
460+
if save:
461+
file.parent.mkdir(parents=True, exist_ok=True) # make directory
462+
cv2.imwrite(str(increment_path(file).with_suffix('.jpg')), crop)
463+
return crop

0 commit comments

Comments
 (0)