This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport TRT fixes to 1.8 (#19983)
* Update MXNet-TRT doc with the new optimize_for API

* [1.x] Move block.optimize_for backend_opts to kwargs (#19386)

* Move block.optimize_for backend_opts to kwargs

* Update Hybridize to use kwargs as backend opts

* Fix lint

* Change clear default to False and allow hybridize+optimize_for calls

* Fix nit

* Address review comments

* Address more review comments

* Address further review comments

* Fix nit

* Add 1:many conversions in nnvm_to_onnx and non-flatten GEMM (#19652)

Signed-off-by: Serge Panev <[email protected]>
Kh4L authored Mar 5, 2021
1 parent 035f721 commit c16aa91
Showing 8 changed files with 467 additions and 182 deletions.
@@ -33,74 +33,81 @@ from mxnet.gluon.model_zoo import vision
import time
import os

ctx = mx.gpu(0)

batch_shape = (1, 3, 224, 224)
x = mx.nd.zeros(batch_shape, ctx=ctx)

model = vision.resnet18_v2(pretrained=True, ctx=ctx)
model.hybridize(static_shape=True, static_alloc=True)
```
In our first section of code we import the modules needed to run MXNet and to time our benchmark runs. We then download a pretrained version of ResNet-18 and [hybridize](https://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) it with `static_alloc` and `static_shape` to get the best performance.

## MXNet Baseline Performance
```python
# Warmup
print('Warming up MXNet')
for i in range(0, 1000):
    out = model(x)
    mx.nd.waitall()

# Timing
print('Starting MXNet timed run')
start = time.time()
for i in range(0, 10000):
    out = model(x)
    mx.nd.waitall()
print(time.time() - start)
```

For this experiment we are strictly interested in inference performance, so to simplify the benchmark we'll pass a tensor filled with zeros as an input.
To help improve the accuracy of our benchmarks, we run a small number of predictions as a warmup before running our timed loop. This ensures that various lazy operations, which do not represent real-world usage, have completed before we measure relative performance improvement. On a system with a V100 GPU, the time taken for our MXNet baseline is **19.5s** (512 samples/s).
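As a quick sanity check, the quoted throughput follows directly from the timed loop (10,000 iterations at batch size 1):

```python
# Throughput implied by the timed loop above.
iterations = 10000
batch_size = 1          # from batch_shape = (1, 3, 224, 224)
elapsed_s = 19.5        # measured baseline time reported above

throughput = iterations * batch_size / elapsed_s  # ~512.8 samples/s
print(int(throughput))
```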

## MXNet with TensorRT Integration Performance
```python
[...]

model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True)

[...]
```

Next we'll run the same model with TensorRT enabled and see how the performance compares.

To use TensorRT optimization with Gluon, we call `optimize_for` with the TensorRT backend and provide some input data that will be used to infer shapes and types (any sample representative of the inference data). The TensorRT backend supports only static shapes, so we need to set `static_alloc` and `static_shape` to `True`.

This will run the subgraph partitioning and replace TensorRT-compatible subgraphs with TensorRT ops containing the TensorRT engines. The model is then ready to be used.

```python
# Warmup
print('Warming up TensorRT')
for i in range(0, 1000):
    out = model(x)
    out[0].wait_to_read()

# Timing
print('Starting TensorRT timed run')
start = time.time()
for i in range(0, 10000):
    out = model(x)
    out[0].wait_to_read()
print(time.time() - start)
```

We run timing with a warmup once again and, on the same machine, the run takes **12.7s** (787 samples/s). A 1.5x speed improvement! Speed improvements when using libraries like TensorRT can come from a variety of optimizations, but in this case our speedups are coming from a technique known as [operator fusion](http://ziheng.org/2016/11/21/fusion-and-runtime-compilation-for-nnvm-and-tinyflow/).

## FP16

We can get a further speedup by turning on TensorRT FP16. This optimization comes almost for free: the only user effort required is to add the `precision` option to the `optimize_for` call.

```python
[...]

model.optimize_for(x, backend='TensorRT', static_alloc=True, static_shape=True, precision='fp16')

[...]
```

We run timing with a warmup once more and we get **7.8s** (1282 samples/s). That's a 2.5x speedup compared to default MXNet!
All the ops used in ResNet-18 are FP16-compatible, so the TensorRT engine was able to run FP16 kernels, hence the extra speedup.
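The speedups quoted throughout this tutorial follow from the three measured runtimes:

```python
# Speedups implied by the three timed runs reported in this tutorial.
baseline_s = 19.5  # plain MXNet
trt_s = 12.7       # TensorRT
fp16_s = 7.8       # TensorRT + FP16

trt_speedup = baseline_s / trt_s    # ~1.5x
fp16_speedup = baseline_s / fp16_s  # 2.5x
print(round(trt_speedup, 1), round(fp16_speedup, 1))
```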


## Operators and Subgraph Fusion

6 changes: 3 additions & 3 deletions example/extensions/lib_pass/README.md
@@ -88,15 +88,15 @@ The `optimize_for` API takes at least 1 argument, `backend` which is a string th
For the Gluon API, `hybridize` can be called on HybridBlocks to execute a graph pass on the internal CachedOp Symbol.

```python
block.hybridize(backend=None, **kwargs)
```

The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which pass will be executed on the model. `**kwargs` may contain other user-specified options that will be passed to the backend APIs. The actual pass runs once, just before the first forward pass.

If you just want to run a graph pass on the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.

```python
block.optimize_for(x, backend=None, **kwargs)
```

When the `optimize_for` API is called on a HybridBlock it runs the graph pass immediately. This lets users export the modified model without running a complete forward pass.
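The difference between the two call patterns can be sketched with a small pure-Python toy (illustrative only, not the MXNet implementation; `TinyBlock` and `'myPass'` are made-up names):

```python
class TinyBlock:
    """Toy analogy of deferred (hybridize) vs. eager (optimize_for) passes."""
    def __init__(self):
        self.passes_run = []
        self._pending = None

    def hybridize(self, backend=None):
        # Like HybridBlock.hybridize: remember the pass, run it lazily
        # just before the first forward call.
        self._pending = backend

    def optimize_for(self, x, backend=None):
        # Like HybridBlock.optimize_for: run the pass immediately,
        # so the block can be exported without a forward pass.
        self.passes_run.append(backend)

    def forward(self, x):
        if self._pending is not None:
            self.passes_run.append(self._pending)
            self._pending = None
        return x

block = TinyBlock()
block.hybridize(backend='myPass')
print(block.passes_run)   # nothing has run yet

block.forward(1)
print(block.passes_run)   # the pass ran just before the first forward

eager = TinyBlock()
eager.optimize_for(1, backend='myPass')
print(eager.passes_run)   # the pass ran immediately, no forward needed
```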
6 changes: 3 additions & 3 deletions example/extensions/lib_subgraph/README.md
@@ -107,15 +107,15 @@ The `optimize_for` API takes at least 1 argument, `backend` which is a string th
For the Gluon API, `hybridize` can be called on HybridBlocks to partition the internal CachedOp Symbol.

```python
block.hybridize(backend=None, clear=True, **kwargs)
```

The `hybridize` function prepares the HybridBlock to be converted into a backend symbol. The `backend` argument is a string that identifies which backend will partition the model. `**kwargs` are other user-specified options (as a Python dictionary of strings mapped to strings) that will be passed to the backend partitioning APIs. The `clear` argument defaults to `True` and clears any previous optimizations done on the block; to chain optimizations together, set `clear` to `False`. The actual partitioning takes place during the forward pass. If you want to use `hybridize` to chain multiple optimizations, be sure to execute a forward pass after each call to `hybridize`.

If you just want to partition the HybridBlock but not run a complete forward pass, you can use the `optimize_for` API that combines the work done in the `hybridize` API with part of the work done in the forward pass.

```python
block.optimize_for(x, backend=None, clear=False, **kwargs)
```

When the `optimize_for` API is called on a HybridBlock it partitions immediately. This lets users export the partitioned model without running a complete forward pass. Chaining multiple optimizations is as simple as calling `optimize_for` multiple times, no need to execute a forward pass (as opposed to `hybridize`).
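The chaining semantics of the `clear` flag can be illustrated with a toy pure-Python model (not MXNet code; `run_pass` is a made-up helper, and the backend names mirror those used elsewhere in this commit):

```python
def run_pass(history, backend, clear=False):
    """Toy model of how `clear` affects chained optimizations."""
    if clear:
        history = []          # discard previously applied passes
    return history + [backend]

passes = run_pass([], 'myProp')
passes = run_pass(passes, 'addInputPass')        # clear=False: chains
print(passes)                                    # both passes applied
passes = run_pass(passes, 'myProp', clear=True)  # clear=True: start over
print(passes)                                    # only the new pass remains
```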
6 changes: 3 additions & 3 deletions example/extensions/lib_subgraph/test_subgraph.py
@@ -92,7 +92,7 @@ def test(backend):
inputs = [a,b]
sym_block = nn.SymbolBlock(sym, inputs)
sym_block.initialize()
sym_block.hybridize(backend=backend, dedup_subgraph=True)
out2 = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out2)

@@ -103,14 +103,14 @@ def test(backend):
sym_block2 = nn.SymbolBlock(sym, inputs)
sym_block2.initialize()
sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend=backend,
dedup_subgraph=True)
sym_block2.export('partitioned')

# Test with additional input to subgraph op
print('-------------------------------')
print('Testing %s Gluon Hybridize partitioning with extra input' % backend)
sym_block2.optimize_for(mx.nd.ones((3,2)), mx.nd.ones((3,2)), backend="addInputPass",
dedup_subgraph=True)
out3 = sym_block2(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
print(out3)

Expand Down
72 changes: 54 additions & 18 deletions python/mxnet/gluon/block.py
@@ -1080,7 +1080,13 @@ def _call_cached_op(self, *args):
out = [out]
return _regroup(out, self._out_format)

def optimize_for(self, x, *args, backend=None, clear=False,
static_alloc=False,
static_shape=False,
inline_limit=2,
forward_bulk_size=None,
backward_bulk_size=None,
**kwargs):
"""Partitions the current HybridBlock and optimizes it for a given backend
without executing a forward pass. Modifies the HybridBlock in-place.
@@ -1108,19 +1114,29 @@ def optimize_for(self, x, *args, backend=None, backend_opts=None, clear=True, **
other inputs to model
backend : str
The name of backend, as registered in `SubgraphBackendRegistry`, default None
clear : bool, default False
Clears any previous optimizations
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
inline_limit : optional int, default 2
Maximum number of operators that can be inlined.
forward_bulk_size : optional int, default None
Segment size of bulk execution during forward pass.
backward_bulk_size : optional int, default None
Segment size of bulk execution during backward pass.
**kwargs: The backend options, optional
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
"""
if len(kwargs) > 0:
self._backend_opts = kwargs

# do hybridize API call
if clear or not self._active:
self.hybridize(True, backend, clear, static_alloc, static_shape,
inline_limit, forward_bulk_size, backward_bulk_size)

# do part of forward API call
has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1155,7 +1171,12 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, backend=None, clear=True,
static_alloc=False, static_shape=False,
inline_limit=2,
forward_bulk_size=None,
backward_bulk_size=None,
**kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
non-hybrid children.
@@ -1165,32 +1186,47 @@ def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **
Whether to turn hybrid on or off.
backend : str
The name of backend, as registered in `SubgraphBackendRegistry`, default None
clear : bool, default True
Clears any previous optimizations
static_alloc : optional bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : optional bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
inline_limit : optional int, default 2
Maximum number of operators that can be inlined.
forward_bulk_size : optional int, default None
Segment size of bulk execution during forward pass.
backward_bulk_size : optional int, default None
Segment size of bulk execution during backward pass.
**kwargs: optional
Backend options.
"""
if len(kwargs) > 0:
self._backend_opts = kwargs

self._backend = backend

self._active = active
self._flags = [("static_alloc", static_alloc), ("static_shape", static_shape),
("inline_limit", inline_limit)]
if forward_bulk_size is not None:
self._flags.append(("forward_bulk_size", forward_bulk_size))
if backward_bulk_size is not None:
self._flags.append(("backward_bulk_size", backward_bulk_size))
if clear:
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. '
'If "{block}" is a child of HybridBlock, the hooks will not take effect.'
.format(block=self))
super(HybridBlock, self).hybridize(active, **kwargs)
super(HybridBlock, self).hybridize(active,
static_alloc=static_alloc,
static_shape=static_shape,
inline_limit=inline_limit,
forward_bulk_size=forward_bulk_size,
backward_bulk_size=backward_bulk_size)

def cast(self, dtype):
self._clear_cached_op()
