Adapt test_tensorrt.py:test_tensorrt_symbol for A100
DickJC123 committed Oct 6, 2020
1 parent 7defa26 commit 578c22f
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tests/python/tensorrt/test_tensorrt.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import os
+import sys
 import ctypes
 import mxnet as mx
 from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str, mx_real_t
@@ -28,6 +29,10 @@
 from mxnet import nd
 from mxnet.gluon.model_zoo import vision
 
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from common import setup_module, with_seed, teardown
+
 ####################################
 ######### FP32/FP16 tests ##########
 ####################################
@@ -61,7 +66,7 @@ def get_baseline(input_data):
     return output
 
 
-def check_tensorrt_symbol(baseline, input_data, fp16_mode, tol):
+def check_tensorrt_symbol(baseline, input_data, fp16_mode, rtol=None, atol=None):
     sym, arg_params, aux_params = get_model(batch_shape=input_data.shape)
     trt_sym = sym.optimize_for('TensorRT', args=arg_params, aux=aux_params, ctx=mx.gpu(0),
                                precision='fp16' if fp16_mode else 'fp32')
@@ -70,17 +75,18 @@ def check_tensorrt_symbol(baseline, input_data, fp16_mode, tol):
                                    grad_req='null', force_rebind=True)
 
     output = executor.forward(is_train=False, data=input_data)
-    assert_almost_equal(output[0].asnumpy(), baseline[0].asnumpy(), atol=tol[0], rtol=tol[1])
+    assert_almost_equal(output[0], baseline[0], rtol=rtol, atol=atol)
 
+@with_seed()
 def test_tensorrt_symbol():
     batch_shape = (32, 3, 224, 224)
     input_data = mx.nd.random.uniform(shape=(batch_shape), ctx=mx.gpu(0))
     baseline = get_baseline(input_data)
     print("Testing resnet50 with TensorRT backend numerical accuracy...")
     print("FP32")
-    check_tensorrt_symbol(baseline, input_data, fp16_mode=False, tol=(1e-4, 1e-4))
+    check_tensorrt_symbol(baseline, input_data, fp16_mode=False)
     print("FP16")
-    check_tensorrt_symbol(baseline, input_data, fp16_mode=True, tol=(1e-1, 1e-2))
+    check_tensorrt_symbol(baseline, input_data, fp16_mode=True, rtol=1e-2, atol=1e-1)
 
 ##############################
 ######### INT8 tests ##########
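
The change replaces the positional tol tuple with rtol/atol keywords, leaving FP32 on the library's default tolerances and giving FP16 explicit looser ones. Below is a minimal illustrative sketch of the standard allclose-style criterion that rtol/atol keywords conventionally express; it is not MXNet's assert_almost_equal implementation (which is assumed here to pick dtype-based defaults when rtol/atol are None), and the within_tolerance helper and the random arrays are hypothetical.

import numpy as np

def within_tolerance(actual, expected, rtol=1e-5, atol=1e-8):
    # Standard allclose-style criterion: |actual - expected| <= atol + rtol * |expected|
    return np.all(np.abs(actual - expected) <= atol + rtol * np.abs(expected))

# FP16 TensorRT outputs are held to the looser tolerances the updated test passes
# (rtol=1e-2, atol=1e-1); the FP32 call omits them and relies on library defaults.
expected = np.random.uniform(size=(32, 1000)).astype(np.float32)
actual = expected + np.random.uniform(-1e-3, 1e-3, size=expected.shape).astype(np.float32)
assert within_tolerance(actual, expected, rtol=1e-2, atol=1e-1)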
