88import  torch 
99
1010
11- def  _create (name , pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
11+ def  _create (name , pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
1212    """Creates a specified YOLOv5 model 
1313
1414    Arguments: 
@@ -18,6 +18,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
1818        classes (int): number of model classes 
1919        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model 
2020        verbose (bool): print all information to screen 
21+         device (str, torch.device, None): device to use for model parameters 
2122
2223    Returns: 
2324        YOLOv5 pytorch model 
@@ -50,7 +51,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
5051                    model .names  =  ckpt ['model' ].names   # set class names attribute 
5152        if  autoshape :
5253            model  =  model .autoshape ()  # for file/URI/PIL/cv2/np inputs and NMS 
53-         device  =  select_device ('0'  if  torch .cuda .is_available () else  'cpu' )   # default to GPU if available 
54+         device  =  select_device ('0'  if  torch .cuda .is_available () else  'cpu' ) if   device   is   None   else   torch . device ( device ) 
5455        return  model .to (device )
5556
5657    except  Exception  as  e :
@@ -59,49 +60,49 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
5960        raise  Exception (s ) from  e 
6061
6162
62- def  custom (path = 'path/to/model.pt' , autoshape = True , verbose = True ):
63+ def  custom (path = 'path/to/model.pt' , autoshape = True , verbose = True ,  device = None ):
6364    # YOLOv5 custom or local model 
64-     return  _create (path , autoshape = autoshape , verbose = verbose )
65+     return  _create (path , autoshape = autoshape , verbose = verbose ,  device = device )
6566
6667
67- def  yolov5s (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
68+ def  yolov5s (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
6869    # YOLOv5-small model https://github.com/ultralytics/yolov5 
69-     return  _create ('yolov5s' , pretrained , channels , classes , autoshape , verbose )
70+     return  _create ('yolov5s' , pretrained , channels , classes , autoshape , verbose ,  device )
7071
7172
72- def  yolov5m (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
73+ def  yolov5m (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
7374    # YOLOv5-medium model https://github.com/ultralytics/yolov5 
74-     return  _create ('yolov5m' , pretrained , channels , classes , autoshape , verbose )
75+     return  _create ('yolov5m' , pretrained , channels , classes , autoshape , verbose ,  device )
7576
7677
77- def  yolov5l (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
78+ def  yolov5l (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
7879    # YOLOv5-large model https://github.com/ultralytics/yolov5 
79-     return  _create ('yolov5l' , pretrained , channels , classes , autoshape , verbose )
80+     return  _create ('yolov5l' , pretrained , channels , classes , autoshape , verbose ,  device )
8081
8182
82- def  yolov5x (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
83+ def  yolov5x (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
8384    # YOLOv5-xlarge model https://github.com/ultralytics/yolov5 
84-     return  _create ('yolov5x' , pretrained , channels , classes , autoshape , verbose )
85+     return  _create ('yolov5x' , pretrained , channels , classes , autoshape , verbose ,  device )
8586
8687
87- def  yolov5s6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
88+ def  yolov5s6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
8889    # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5 
89-     return  _create ('yolov5s6' , pretrained , channels , classes , autoshape , verbose )
90+     return  _create ('yolov5s6' , pretrained , channels , classes , autoshape , verbose ,  device )
9091
9192
92- def  yolov5m6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
93+ def  yolov5m6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
9394    # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5 
94-     return  _create ('yolov5m6' , pretrained , channels , classes , autoshape , verbose )
95+     return  _create ('yolov5m6' , pretrained , channels , classes , autoshape , verbose ,  device )
9596
9697
97- def  yolov5l6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
98+ def  yolov5l6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
9899    # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5 
99-     return  _create ('yolov5l6' , pretrained , channels , classes , autoshape , verbose )
100+     return  _create ('yolov5l6' , pretrained , channels , classes , autoshape , verbose ,  device )
100101
101102
102- def  yolov5x6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ):
103+ def  yolov5x6 (pretrained = True , channels = 3 , classes = 80 , autoshape = True , verbose = True ,  device = None ):
103104    # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5 
104-     return  _create ('yolov5x6' , pretrained , channels , classes , autoshape , verbose )
105+     return  _create ('yolov5x6' , pretrained , channels , classes , autoshape , verbose ,  device )
105106
106107
107108if  __name__  ==  '__main__' :
0 commit comments