Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update MXNet-TRT doc with the new optimize_for API
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Oct 21, 2020
1 parent 0bc01e9 commit 5a0601e
Showing 1 changed file with 52 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,74 +33,81 @@ from mxnet.gluon.model_zoo import vision
import time
import os

ctx=mx.gpu(0)

batch_shape = (1, 3, 224, 224)
resnet18 = vision.resnet18_v2(pretrained=True)
resnet18.hybridize()
resnet18.forward(mx.nd.zeros(batch_shape))
resnet18.export('resnet18_v2')
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v2', 0)
x = mx.nd.zeros(batch_shape, ctx=ctx)

model = vision.resnet18_v2(pretrained=True, ctx=ctx)
model.hybridize(static_shape=True, static_alloc=True)

```
In our first section of code we import the modules needed to run MXNet, and to time our benchmark runs. We then download a pretrained version of Resnet18, hybridize it, and load it symbolically. It's important to note that the experimental version of TensorRT integration will only work with the symbolic MXNet API. If you're using Gluon, you must [hybridize](https://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) your computation graph and export it as a symbol before running inference. This may be addressed in future releases of MXNet, but in general if you're concerned about getting the best inference performance possible from your models, it's a good practice to hybridize.
In our first section of code we import the modules needed to run MXNet, and to time our benchmark runs. We then download a pretrained version of Resnet18. We hybridize (link to hybridization) it with static_alloc and static_shape to get the best performance.

## MXNet Baseline Performance
```python
# Create sample input
input = mx.nd.zeros(batch_shape)

# Execute with MXNet
executor = sym.simple_bind(ctx=mx.gpu(0), data=batch_shape, grad_req='null', force_rebind=True)
executor.copy_params_from(arg_params, aux_params)

# Warmup
print('Warming up MXNet')
for i in range(0, 10):
y_gen = executor.forward(is_train=False, data=input)
y_gen[0].wait_to_read()
for i in range(0, 1000):
out = model(x)
mx.nd.waitall()

# Timing
print('Starting MXNet timed run')
start = time.process_time()
start = time.time()
for i in range(0, 10000):
y_gen = executor.forward(is_train=False, data=input)
y_gen[0].wait_to_read()
end = time.time()
print(time.process_time() - start)
out = model(x)
mx.nd.waitall()
print(time.time() - start)
```

We are interested in inference performance, so to simplify the benchmark we'll pass a tensor filled with zeros as an input. We bind a symbol as usual, returning an MXNet executor, and we run forward on this executor in a loop. To help improve the accuracy of our benchmarks we run a small number of predictions as a warmup before running our timed loop. On a modern PC with an RTX 2070 GPU the time taken for our MXNet baseline is **17.20s**. Next we'll run the same model with TensorRT enabled, and see how the performance compares.
For this experiment we are strictly interested in inference performance, so to simplify the benchmark we'll pass a tensor filled with zeros as an input.
To help improve the accuracy of our benchmarks we run a small number of predictions as a warmup before running our timed loop. This will ensure various lazy operations, which do not represent real-world usage, have completed before we measure relative performance improvement. On a system with a V100 GPU, the time taken for our MXNet baseline is **19.5s** (512 samples/s).

## MXNet with TensorRT Integration Performance
```python
# Execute with TensorRT
print('Building TensorRT engine')
trt_sym = sym.get_backend_symbol('TensorRT')
arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, aux_params)
mx.contrib.tensorrt.set_use_fp16(True)
executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
grad_req='null', force_rebind=True)
executor.copy_params_from(arg_params, aux_params)
[...]

model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True)

[...]
```

We use a few TensorRT specific API calls from the contrib package here to setup our parameters and indicate we'd like to run inference in fp16 mode. We then call simple_bind as normal and copy our parameter dictionaries to our executor.
Next we'll run the same model with TensorRT enabled, and see how the performance compares.

To use TensorRT optimization with the Gluon, we need to call optimize_for with the TensorRT backend and provide some input data that will be used to infer shape and types (any sample representing the inference data). TensorRT backend supports only static shape, so we need to set static_alloc and static_shape to True.

This will run the subgraph partitioning and replace TensorRT compatible subgraphs with TensorRT ops containing the TensorRT engines. It's ready to be used.

```python
#Warmup
print('Warming up TensorRT')
for i in range(0, 10):
y_gen = executor.forward(is_train=False, data=input)
y_gen[0].wait_to_read()
# Warmup
for i in range(0, 1000):
out = model(x)
out[0].wait_to_read()

# Timing
print('Starting TensorRT timed run')
start = time.process_time()
start = time.time()
for i in range(0, 10000):
y_gen = executor.forward(is_train=False, data=input)
y_gen[0].wait_to_read()
end = time.time()
print(time.process_time() - start)
out = model(x)
out[0].wait_to_read()
print(time.time() - start)
```

We run timing with a warmup once more, and on the same machine, run in **9.83s**. A 1.75x speed improvement! Speed improvements when using libraries like TensorRT can come from a variety of optimizations, but in this case our speedups are coming from a technique known as [operator fusion](http://ziheng.org/2016/11/21/fusion-and-runtime-compilation-for-nnvm-and-tinyflow/).
We run timing with a warmup once again, and on the same machine, run in **12.7s** (787 samples/s). A 1.5x speed improvement! Speed improvements when using libraries like TensorRT can come from a variety of optimizations, but in this case our speedups are coming from a technique known as [operator fusion](http://ziheng.org/2016/11/21/fusion-and-runtime-compilation-for-nnvm-and-tinyflow/).

## FP16

We can give a simple speed up by turning on TensorRT FP16. This optimization comes almost as a freebie and doesn't need any other use effort than adding the optimize_for parameter precision.

```python
[...]

model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True, backend_opts={'precision':'fp16'})

[...]
```

We run timing with a warmup once more and we get **7.8s** (1282 samples/s). That's 2.5x speedup compared to the default MXNet!
All the ops used in ResNet-18 are FP16 compatible, so the TensorRT engine was able to run FP16 kernels, hence the extra speed up.


## Operators and Subgraph Fusion

Expand Down

0 comments on commit 5a0601e

Please sign in to comment.