Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring TensorRT Python interface #307

Merged
merged 2 commits into from
Feb 10, 2022
Merged

Conversation

zhiqwang
Copy link
Owner

Refactoring TensorRT Python interface and leaving some interfaces and parameters for subsequent use

import os
import torch

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

assert torch.cuda.is_available()
device = torch.device('cuda')

from yolort.utils import get_image_from_url, read_image_to_tensor
from yolort.v5 import letterbox, attempt_download
from yolort.runtime import PredictorTRT
from yolort.runtime.trt_helper import EngineBuilder
from yolort.runtime.yolo_graphsurgeon import YOLOGraphSurgeon

# Define some parameters
img_size = 640
stride = 64
fixed_shape = True
score_thresh = 0.35
iou_thresh = 0.45
detections_per_img = 100
precision = "fp32"

# yolov5s6.pt is downloaded from 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n6.pt'
model_path = "yolov5n6.pt"

checkpoint_path = attempt_download(model_path)
onnx_path = "yolov5n6.onnx"
engine_path = "yolov5n6.engine"

img_source = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg"
# img_source = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/zidane.jpg"
img_raw = get_image_from_url(img_source)

# Pre Processing
image = letterbox(img_raw, new_shape=(img_size, img_size), stride=stride, auto=not fixed_shape)[0]
image = read_image_to_tensor(image)
image = image[None]
image = image.to(device)
image = image.contiguous()

# Export to ONNX models
yolo_gs = YOLOGraphSurgeon(model_path, input_sample=image, version="r6.0", enable_dynamic=False)
# Embed the `EfficientNMS_TRT` at the end of `LogitsDecoder`.
yolo_gs.register_nms(score_thresh=score_thresh, nms_thresh=iou_thresh, detections_per_img=detections_per_img)

yolo_gs.save(onnx_path)

# Build TensorRT Engine
engine_builder = EngineBuilder(verbose=False, workspace=12, precision=precision)
engine_builder.create_network(onnx_path)
engine_builder.create_engine(engine_path)

# Inference on TensorRT
engine = PredictorTRT(engine_path, device=device, precision=precision)
engine.warmup(img_size=image.shape)

# Inferencing
detections = engine.run_on_image(image)

@zhiqwang zhiqwang added API Library use interface code quality Code format and unit tests labels Feb 10, 2022
@zhiqwang zhiqwang merged commit 720cd32 into main Feb 10, 2022
@zhiqwang zhiqwang deleted the refactor-trt-interface branch February 10, 2022 09:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API Library use interface code quality Code format and unit tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant