Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update test_cvnets.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Caenorst committed Mar 6, 2019
1 parent e10ffc4 commit 0920ae9
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions tests/python/tensorrt/test_cvnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,22 @@


def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128):
mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)
mx.contrib.tensorrt.set_use_fp16(False)
h, w = 32, 32
net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
data = mx.sym.var('data')

net.hybridize()
net.forward(mx.nd.zeros((batch_size, 3, h, w)))
net.export(model_name)
_sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
if use_tensorrt:
out = net(data)
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
all_params = dict([(k, v.data()) for k, v in net.collect_params().items()])
executor = mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params,
data=(batch_size,3, h, w),
softmax_label=(batch_size,), grad_req='null',
force_rebind=True)
sym = _sym.get_backend_symbol('TensorRT')
mx.contrib.tensorrt.init_tensorrt_params(sym, arg_params, aux_params)
else:
# Convert gluon model to Symbolic
net.hybridize()
net.forward(mx.ndarray.zeros((batch_size, 3, h, w)))
net.export(model_name)
symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w),
softmax_label=(batch_size,))
executor.copy_params_from(arg_params, aux_params)
sym = _sym
executor = sym.simple_bind(ctx=ctx, data=(batch_size, 3, h, w),
softmax_label=(batch_size,),
grad_req='null', force_rebind=True)
executor.copy_params_from(arg_params, aux_params)
return executor


Expand Down Expand Up @@ -126,7 +120,7 @@ def run_experiment_for(model_name, batch_size, num_workers):


def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1):
original_try_value = mx.contrib.tensorrt.get_use_tensorrt()
original_use_fp16 = mx.contrib.tensorrt.get_use_fp16()
try:
models = [
'cifar_resnet20_v1',
Expand Down Expand Up @@ -170,7 +164,7 @@ def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1):

print("Test duration: %.2f seconds" % test_duration)
finally:
mx.contrib.tensorrt.set_use_tensorrt(original_try_value)
mx.contrib.tensorrt.set_use_fp16(original_use_fp16)


if __name__ == '__main__':
Expand Down

0 comments on commit 0920ae9

Please sign in to comment.