Skip to content

Commit

Permalink
Refactoring TensorRT Python interface (#307)
Browse files Browse the repository at this point in the history
* Leaving some interfaces and parameters for subsequent use

* Update API in tutorials
  • Loading branch information
zhiqwang authored Feb 10, 2022
1 parent e878451 commit 720cd32
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 109 deletions.
142 changes: 71 additions & 71 deletions notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb

Large diffs are not rendered by default.

82 changes: 50 additions & 32 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#

import logging
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Optional, Tuple, Union

try:
Expand Down Expand Up @@ -41,13 +41,13 @@ class YOLOTRTModule(nn.Module):
Remove the ``torchvision::nms`` in this wrapper, due to the fact that some third-party
inference frameworks currently do not support this operator very well.
Args:
checkpoint_path (string): Path of the trained YOLOv5 checkpoint.
version (string): Upstream YOLOv5 version. Default: 'r6.0'
"""

def __init__(
self,
checkpoint_path: str,
version: str = "r6.0",
):
def __init__(self, checkpoint_path: str, version: str = "r6.0"):
super().__init__()
model_info = load_from_ultralytics(checkpoint_path, version=version)

Expand Down Expand Up @@ -90,7 +90,7 @@ def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
file_path: Union[str, PosixPath],
input_sample: Optional[Tensor] = None,
opset_version: int = 11,
enable_dynamic: bool = True,
Expand All @@ -100,10 +100,11 @@ def to_onnx(
Saves the model in ONNX format.
Args:
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None.
opset_version: Opset version we export the model to the onnx submodule. Default: 11.
enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True.
file_path (Union[string, PosixPath]): The path of the file the onnx model should
be saved to.
input_sample (Tensor, Optional): An input for tracing. Default: None.
opset_version (int): Opset version we export the model to the onnx submodule. Default: 11.
enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: True.
**kwargs: Will be passed to torch.onnx.export function.
"""
if input_sample is None:
Expand Down Expand Up @@ -140,12 +141,33 @@ class EngineBuilder:
Parses an ONNX graph and builds a TensorRT engine from it.
"""

def __init__(self, verbose=False, workspace=4):
def __init__(
self,
verbose: bool = False,
workspace: int = 4,
precision: str = "fp32",
enable_dynamic: bool = False,
max_batch_size: int = 16,
calib_input: Optional[str] = None,
calib_cache: Optional[str] = None,
calib_num_images: int = 5000,
calib_batch_size: int = 8,
):
"""
Args:
verbose: If enabled, a higher verbosity level will be
set on the TensorRT logger.
workspace: Max memory workspace to allow, in Gb.
verbose (bool): If enabled, a higher verbosity level will be
set on the TensorRT logger. Default: False
workspace (int): Max memory workspace to allow, in Gb.
precision (string): The datatype to use for the engine inference, either 'fp32',
'fp16' or 'int8'. Default: 'fp32'
enable_dynamic (bool): Whether to enable dynamic shapes. Default: False
max_batch_size (int): Maximum batch size reserved for dynamic shape inference. Default: 16
calib_input (string, optional): The path to a directory holding the calibration images.
Default: None
calib_cache (string, optional): The path where to write the calibration cache to,
or if it already exists, load it from. Default: None
calib_num_images (int): The maximum number of images to use for calibration. Default: 5000
calib_batch_size (int): The batch size to use for the calibration process. Default: 8
"""
self.logger = trt.Logger(trt.Logger.INFO)
if verbose:
Expand All @@ -161,6 +183,16 @@ def __init__(self, verbose=False, workspace=4):
self.network = None
self.parser = None

# Leaving some interfaces and parameters for subsequent use, but we have not yet
# implemented the following functionality
self.precision = precision
self.enable_dynamic = enable_dynamic
self.max_batch_size = max_batch_size
self.calib_input = calib_input
self.calib_cache = calib_cache
self.calib_num_images = calib_num_images
self.calib_batch_size = calib_batch_size

def create_network(self, onnx_path: str):
"""
Parse the ONNX graph and create the corresponding TensorRT network definition.
Expand All @@ -185,31 +217,17 @@ def create_network(self, onnx_path: str):
for output in outputs:
logger.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")

def create_engine(
self,
engine_path: str,
*,
precision: str = "fp32",
max_batch_size: int = 32,
calib_input: Optional[str] = None,
calib_cache: Optional[str] = None,
calib_num_images: int = 5000,
calib_batch_size: int = 8,
):
def create_engine(self, engine_path: str):
"""
Build the TensorRT engine and serialize it to disk.
Args:
engine_path: The path where to serialize the engine to.
precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
calib_input: The path to a directory holding the calibration images.
calib_cache: The path where to write the calibration cache to, or if it already
exists, load it from.
calib_num_images: The maximum number of images to use for calibration.
calib_batch_size: The batch size to use for the calibration process.
"""
engine_path = Path(engine_path)
engine_path.parent.mkdir(parents=True, exist_ok=True)

precision = self.precision
logger.info(f"Building {precision} Engine in {engine_path}")

# Process the batch size and profile
Expand Down
16 changes: 12 additions & 4 deletions yolort/runtime/y_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class PredictorTRT:
single device for a single input image.
Args:
engine_path (str): Path of the ONNX checkpoint.
engine_path (string): Path of the ONNX checkpoint.
device (torch.device): The CUDA device to be used for inferencing.
precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
Examples:
>>> import cv2
Expand All @@ -55,6 +56,7 @@ def __init__(
self,
engine_path: str,
device: torch.device = torch.device("cuda"),
precision: str = "fp32",
) -> None:
self.engine_path = engine_path
self.device = device
Expand All @@ -64,7 +66,13 @@ def __init__(

self.engine = self._build_engine()
self._set_context()
self.half = False

if precision == "fp32":
self.half = False
elif precision == "fp16":
self.half = True
else:
raise NotImplementedError(f"Currently not supports precision: {precision}")

def _build_engine(self):
logger.info(f"Loading {self.engine_path} for TensorRT inference...")
Expand Down Expand Up @@ -136,11 +144,11 @@ def postprocessing(all_boxes, all_scores, all_labels, all_num_dets):

return detections

def warmup(self, img_size=(1, 3, 320, 320), half=False):
def warmup(self, img_size=(1, 3, 320, 320)):
# Warmup model by running inference once
# only warmup GPU models
if isinstance(self.device, torch.device) and self.device.type != "cpu":
image = torch.zeros(*img_size).to(self.device).type(torch.half if half else torch.float)
image = torch.zeros(*img_size).to(self.device).type(torch.half if self.half else torch.float)
self(image)

def run_wo_postprocessing(self, image: Tensor):
Expand Down
14 changes: 12 additions & 2 deletions yolort/runtime/yolo_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class YOLOGraphSurgeon:
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: False.
device (torch.device): The device to be used for importing ONNX. Default: torch.device("cpu").
precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
"""

def __init__(
Expand All @@ -60,6 +61,7 @@ def __init__(
version: str = "r6.0",
enable_dynamic: bool = False,
device: torch.device = torch.device("cpu"),
precision: str = "fp32",
):
checkpoint_path = Path(checkpoint_path)
assert checkpoint_path.exists()
Expand All @@ -82,6 +84,7 @@ def __init__(
self.graph.fold_constants()
self.num_classes = model.num_classes
self.batch_size = 1
self.precision = precision

def infer(self):
"""
Expand Down Expand Up @@ -165,6 +168,13 @@ def register_nms(
"box_coding": 0,
}

if self.precision == "fp32":
dtype_output = np.float32
elif self.precision == "fp16":
dtype_output = np.float16
else:
raise NotImplementedError(f"Currently not supports precision: {self.precision}")

# NMS Outputs
output_num_detections = gs.Variable(
name="num_detections",
Expand All @@ -173,12 +183,12 @@ def register_nms(
) # A scalar indicating the number of valid detections per batch image.
output_boxes = gs.Variable(
name="detection_boxes",
dtype=np.float32,
dtype=dtype_output,
shape=[self.batch_size, detections_per_img, 4],
)
output_scores = gs.Variable(
name="detection_scores",
dtype=np.float32,
dtype=dtype_output,
shape=[self.batch_size, detections_per_img],
)
output_labels = gs.Variable(
Expand Down

0 comments on commit 720cd32

Please sign in to comment.