Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
reformat trt to use subgraph API, add fp16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
Caenorst committed Mar 28, 2019
1 parent 645c778 commit ccae63f
Show file tree
Hide file tree
Showing 29 changed files with 1,006 additions and 2,371 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/onnx-tensorrt
1 change: 1 addition & 0 deletions ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ COPY runtime_functions.sh /work/

WORKDIR /work/mxnet
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.0/targets/x86_64-linux/include/
1 change: 0 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,6 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
*/
MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
SymbolHandle *out);

/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
119 changes: 37 additions & 82 deletions python/mxnet/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,95 +16,50 @@
# under the License.

""" Module to enable the use of TensorRT optimized graphs."""

import ctypes
import logging
import os

from .. import symbol as sym

from ..base import _LIB, SymbolHandle, MXNetError
from ..base import check_call


def set_use_tensorrt(status):
def set_use_fp16(status):
"""
Set an environment variable which will enable or disable the use of TensorRT in the backend.
Note: this is useful for A/B testing purposes.
:param status: Boolean, true if TensorRT optimization should be applied, False for legacy
behaviour.
Set an environment variable which will enable or disable the use of FP16 precision in
TensorRT
Note: The mode FP16 force the whole TRT node to be executed in FP16
:param status: Boolean, True if TensorRT should run in FP16, False for FP32
"""
os.environ["MXNET_USE_TENSORRT"] = str(int(status))

os.environ["MXNET_TENSORRT_USE_FP16"] = str(int(status))

def get_use_tensorrt():
def get_use_fp16():
"""
Get an environment variable which describes if TensorRT is currently enabled in the backend.
Note: this is useful for A/B testing purposes.
:return: Boolean, true if TensorRT optimization should be applied, False for legacy
behaviour.
Get an environment variable which describes if TensorRT is currently running in FP16
:return: Boolean, true if TensorRT is running in FP16, False for FP32
"""
return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1)
return bool(int(os.environ.get("MXNET_TENSORRT_USE_FP16", 1)) == 1)


def get_optimized_symbol(executor):
def init_tensorrt_params(sym, arg_params, aux_params):
"""
Take an executor's underlying symbol graph and return its generated optimized version.
Parameters
----------
executor :
An executor for which you want to see an optimized symbol. Getting an optimized symbol
is useful to compare and verify the work TensorRT has done against a legacy behaviour.
Returns
-------
symbol : nnvm::Symbol
The nnvm symbol optimized.
"""
handle = SymbolHandle()
try:
check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle)))
result = sym.Symbol(handle=handle)
return result
except MXNetError:
logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure '
'build was compiled with MXNET_USE_TENSORRT enabled.')
raise


def tensorrt_bind(symbol, ctx, all_params, type_dict=None, stype_dict=None, group2ctx=None,
**kwargs):
"""Bind current symbol to get an optimized trt executor.
Parameters
----------
symbol : Symbol
The symbol you wish to bind, and optimize with TensorRT.
ctx : Context
The device context the generated executor to run on.
all_params : Dict of str->ndarray
A dictionary of mappings from parameter names to parameter NDArrays.
type_dict : Dict of str->numpy.dtype
Input type dictionary, name->dtype
stype_dict : Dict of str->str
Input storage type dictionary, name->storage_type
group2ctx : Dict of string to mx.Context
The dict mapping the `ctx_group` attribute to the context assignment.
kwargs : Dict of str->shape
Input shape dictionary, name->shape
Returns
-------
executor : mxnet.Executor
An optimized TensorRT executor.
Set weights in attributes of TensorRT nodes
:param sym: Symbol, the symbol graph should contains some TensorRT nodes
:param arg_params: arg_params
:param aux_params: aux_params
:return arg_params, aux_params: remaining params that are not in TensorRT nodes
"""
kwargs['shared_buffer'] = all_params
return symbol.simple_bind(ctx, type_dict=type_dict, stype_dict=stype_dict,
group2ctx=group2ctx, **kwargs)
for s in sym.get_internals():
new_params_names = ""
tensorrt_params = {}
if 'subgraph_params_names' in s.list_attr():
keys = s.list_attr()['subgraph_params_names'].split(';')
for k in keys:
if k in arg_params:
new_params_names += k + ";"
tensorrt_params['subgraph_param_' + k] = arg_params[k]
arg_params.pop(k)
elif k in aux_params:
new_params_names += k + ";"
tensorrt_params['subgraph_param_' + k] = aux_params[k]
aux_params.pop(k)
new_attrs = {}
for k, v in tensorrt_params.items():
new_attrs[k] = str(v.handle.value)
if len(new_attrs) > 0:
s._set_attr(**new_attrs)
s._set_attr(subgraph_params_names=new_params_names[:-1])
return arg_params, aux_params
48 changes: 7 additions & 41 deletions src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
#include <mxnet/executor.h>
#include "./c_api_common.h"
#include "../executor/graph_executor.h"
#if MXNET_USE_TENSORRT
#include "../executor/trt_graph_executor.h"
#endif // MXNET_USE_TENSORRT

int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
Executor *exec = static_cast<Executor*>(handle);
Expand Down Expand Up @@ -441,38 +438,12 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
std::vector<NDArray> in_arg_vec;
std::vector<NDArray> arg_grad_vec;
std::vector<NDArray> aux_state_vec;
#if MXNET_USE_TENSORRT
// If we've built with TensorRT support we by default return an TRTExecutor.
// Users can override this behaviour via env var, which is useful for example for A/B
// performance testing.
if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
*out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec,
&arg_grad_ctx_vec, &aux_state_ctx_vec,
&arg_shape_map, &arg_dtype_map, &arg_stype_map,
&grad_req_type_vec, shared_arg_name_set,
&in_arg_vec, &arg_grad_vec, &aux_state_vec,
use_shared_buffer ? &shared_buffer_map : nullptr,
reinterpret_cast<Executor*>(shared_exec_handle));
} else {
// Checks to see if this env var has been set to true or false by the user.
// If the user is using a TensorRT build, but has not enabled TRT at inference time, warn
// them and describe further steps.
const int unset_indicator = std::numeric_limits<int>::quiet_NaN();
if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) {
LOG(INFO) << "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT "
"environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) "
"to enable.";
}
#endif // MXNET_USE_TENSORRT
*out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
&arg_grad_vec, &aux_state_vec,
use_shared_buffer ? &shared_buffer_map : nullptr,
reinterpret_cast<Executor*>(shared_exec_handle));
#if MXNET_USE_TENSORRT
}
#endif // MXNET_USE_TENSORRT
*out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
&arg_grad_vec, &aux_state_vec,
use_shared_buffer ? &shared_buffer_map : nullptr,
reinterpret_cast<Executor*>(shared_exec_handle));

// copy ndarray ptrs to ret->handles so that front end
// can access them
Expand Down Expand Up @@ -633,14 +604,9 @@ int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
auto s = new nnvm::Symbol();
API_BEGIN();

#if MXNET_USE_TENSORRT
auto exec = static_cast<exec::TrtGraphExecutor*>(handle);
auto exec = static_cast<exec::GraphExecutor*>(handle);
*s = exec->GetOptimizedSymbol();
*out = s;
#else
LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with "
"MXNET_USE_TENSORRT enabled. Please re-compile MXNet with TensorRT support.";
#endif // MXNET_USE_TENSORRT

API_END_HANDLE_ERROR(delete s);
}
Expand Down
Loading

0 comments on commit ccae63f

Please sign in to comment.