Skip to content

How to compile Yolov5 for AWS Neuron #8619

@fishuke

Description

@fishuke

I ran into a lack of documentation on this, so I wanted to share how I compiled yolov5s6 for AWS Neuron and ran inference with it.

Prerequisites:

Linux (this does not work on macOS), preferably an AWS Inf1 instance

Create a Python 3.7.10 venv.
Follow here
Install the YOLOv5 requirements
pip install "torch-neuron==1.10.2.*" "neuron-cc[tensorflow]" "protobuf<4" torchvision==0.11.3 --extra-index-url https://pip.repos.neuron.amazonaws.com

Compiler:
import torch
import torch_neuron

# Fetch the pretrained network from the Ultralytics hub; substitute whichever
# YOLOv5 variant you need for 'yolov5s6'.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s6')

# Switch off the `inplace` flag on every submodule that exposes one.
for submodule in model.modules():
    if not hasattr(submodule, 'inplace'):
        continue
    submodule.inplace = False

# All-zero example input used for analysis and tracing.
# Customize the size here; 640x640 is common.
input_shape = (1, 3, 640, 640)
fake_image = torch.zeros(input_shape, dtype=torch.float32)

# Analyze the model with the example input before tracing.
# NOTE(review): on failure the exact same call is simply attempted a second
# time -- presumably a workaround for a first-pass failure; confirm it is
# still needed.
try:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])
except Exception:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])

# Trace the model for Neuron using the same example input.
model_neuron = torch.neuron.trace(model, example_inputs=[fake_image])

# Export to saved model
model_neuron.save("neuron_yolov5s6.pt")
Inference:
import cv2
import torch
import numpy as np
import torch.neuron
from util_yolo import non_max_suppression

# Read the input image. cv2.imread signals a missing/unreadable file by
# returning None rather than raising, so fail early with a clear message
# instead of a cryptic error from cv2.resize.
im = cv2.imread('img.jpg')
if im is None:
    raise FileNotFoundError("could not read image 'img.jpg'")

# Plain resize to 640x640 (no letterboxing, so aspect ratio is not kept).
im = cv2.resize(im, (640, 640), interpolation=cv2.INTER_AREA)
# Keep a copy of the resized image; detections are drawn on copies of it.
img0 = im.copy()

# Convert to the tensor fed to the model.
im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im)  # the reversed view has negative strides
im = torch.from_numpy(im).float()  # uint8 to fp32
im /= 255  # 0 - 255 to 0.0 - 1.0
if im.ndim == 3:
    im = im[None]  # expand for batch dim

# Load the compiled model
model = torch.jit.load('neuron_yolov5s6.pt')

# Class names, indexed by the integer class id of each detection.
CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
        'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
        'teddy bear', 'hair drier', 'toothbrush']

# Inference
pred = model(im)
pred = non_max_suppression(pred)  # nms function used same as yolov5 detect.py

# Drawing colours are the same for every box, so define them once.
color = (30, 30, 30)
txt_color = (255, 255, 255)

# Process predictions
for det in pred:  # per image
    im0 = img0.copy()
    h_size, w_size = im.shape[-2:]
    print(h_size, w_size)
    # Line width is derived from the image being drawn on (H, W, C), not from
    # the 4-D input tensor.
    lw = max(round(sum(im0.shape) / 2 * 0.003), 2)
    tf = max(lw - 1, 1)  # font thickness
    if len(det):
        # Write results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            label = f'{CLASSES[c]} {conf:.2f}'
            print(label)
            # Box corners are used as-is, i.e. as pixel coordinates of im0.
            p1 = (int(xyxy[0]), int(xyxy[1]))
            p2 = (int(xyxy[2]), int(xyxy[3]))
            cv2.rectangle(im0, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
            print(f'p1={p1}, p2={p2}, box={xyxy}')
            w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
            outside = p1[1] - h - 3 >= 0  # label fits outside box
            # Filled background for the label: above the box when it fits,
            # otherwise just inside the top edge.
            label_corner = (p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3)
            cv2.rectangle(im0, p1, label_corner, color, -1, cv2.LINE_AA)  # filled
            text_origin = (p1[0], p1[1] - 2 if outside else p1[1] + h + 2)
            cv2.putText(im0,
                        label,
                        text_origin,
                        0,
                        lw / 3,
                        txt_color,
                        thickness=tf,
                        lineType=cv2.LINE_AA)
    # Save results (image with detections). cv2.imwrite reports failure via
    # its return value, so surface it instead of silently dropping the output.
    status = cv2.imwrite('out.jpg', im0)
    if not status:
        raise OSError("failed to write 'out.jpg'")

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions