Skip to content

Commit a48c519

Browse files
authored
Fix select_device() for Multi-GPU (ultralytics#6434)
* Fix `select_device()` for Multi-GPU Possible fix for ultralytics#6431 * Update torch_utils.py * Update torch_utils.py * Update torch_utils.py * Update torch_utils.py * Update * Update * Update * Update * Update * Update * Update * Update * Update
1 parent 7bb4d71 commit a48c519

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

utils/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@
2929
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
3030
from utils.general import (LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
3131
segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
32-
from utils.torch_utils import torch_distributed_zero_first
32+
from utils.torch_utils import device_count, torch_distributed_zero_first
3333

3434
# Parameters
3535
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
3636
IMG_FORMATS = ['bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'] # include image suffixes
3737
VID_FORMATS = ['asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'wmv'] # include video suffixes
38-
DEVICE_COUNT = max(torch.cuda.device_count(), 1)
38+
DEVICE_COUNT = max(device_count(), 1) # number of CUDA devices
3939

4040
# Get orientation exif tag
4141
for orientation in ExifTags.TAGS.keys():

utils/torch_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
5353
return '' # not a git repository
5454

5555

56+
def device_count():
57+
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count().
58+
try:
59+
cmd = 'nvidia-smi -L | wc -l'
60+
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
61+
except Exception as e:
62+
return 0
63+
64+
5665
def select_device(device='', batch_size=0, newline=True):
5766
# device = 'cpu' or '0' or '0,1,2,3'
5867
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
@@ -61,10 +70,10 @@ def select_device(device='', batch_size=0, newline=True):
6170
if cpu:
6271
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
6372
elif device: # non-cpu device requested
64-
nd = torch.cuda.device_count() # number of CUDA devices
65-
assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'
73+
nd = device_count() # number of CUDA devices
6674
assert nd > int(max(device.split(','))), f'Invalid `--device {device}` request, valid devices are 0 - {nd - 1}'
67-
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable (must be after asserts)
75+
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
76+
assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'
6877

6978
cuda = not cpu and torch.cuda.is_available()
7079
if cuda:

0 commit comments

Comments
 (0)