Skip to content

Commit 24d8af9

Browse files
authored
New val.py cuda variable (ultralytics#6957)
* New val.py `cuda` variable — fix for ONNX GPU val.
* Update val.py
1 parent 51d2a2c commit 24d8af9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

val.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def run(data,
143143
batch_size = model.batch_size
144144
else:
145145
device = model.device
146-
if not pt or jit:
146+
if not (pt or jit):
147147
batch_size = 1 # export.py models default to batch-size 1
148148
LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
149149

@@ -152,6 +152,7 @@ def run(data,
152152

153153
# Configure
154154
model.eval()
155+
cuda = device.type != 'cpu'
155156
is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt') # COCO dataset
156157
nc = 1 if single_cls else int(data['nc']) # number of classes
157158
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
@@ -177,7 +178,7 @@ def run(data,
177178
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
178179
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
179180
t1 = time_sync()
180-
if pt or jit or engine:
181+
if cuda:
181182
im = im.to(device, non_blocking=True)
182183
targets = targets.to(device)
183184
im = im.half() if half else im.float() # uint8 to fp16/32

0 commit comments

Comments (0)