Skip to content

Commit c4c0ee8

Browse files
authored
Apple MPS -> CPU NMS fallback strategy (#9600)
Until more ops are fully supported this update will allow for seamless MPS inference (but slower MPS to CPU transfer before NMS, so slower NMS times). Partially resolves #9596 Signed-off-by: Glenn Jocher <[email protected]> Signed-off-by: Glenn Jocher <[email protected]>
1 parent bd9c0c4 commit c4c0ee8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

utils/general.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,8 @@ def non_max_suppression(
843843
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
844844
prediction = prediction[0] # select only inference output
845845

846+
if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
847+
prediction = prediction.cpu()
846848
bs = prediction.shape[0] # batch size
847849
nc = prediction.shape[2] - nm - 5 # number of classes
848850
xc = prediction[..., 4] > conf_thres # candidates

0 commit comments

Comments
 (0)