Skip to content

Commit

Permalink
Fixes for CI
Browse files Browse the repository at this point in the history
  • Loading branch information
Marek Kolodziej committed Jul 17, 2018
1 parent 5402e42 commit 6c7dbeb
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 305 deletions.
2 changes: 1 addition & 1 deletion ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ unittest_ubuntu_tensorrt_gpu() {
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
nosetests-3.4 --verbose tests/python/tensorrt
nosetests-3.4 --verbose --processes=1 --process-restartworker tests/python/tensorrt
}

# quantization gpu currently only runs on P3 instances
Expand Down
46 changes: 46 additions & 0 deletions tests/python/tensorrt/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
# pylint: disable=unused-import
import unittest
# pylint: enable=unused-import
import numpy as np
import mxnet as mx
from ctypes.util import find_library

def check_tensorrt_installation():
assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library"

def get_use_tensorrt():
return int(os.environ.get("MXNET_USE_TENSORRT", 0))

def set_use_tensorrt(status=False):
os.environ["MXNET_USE_TENSORRT"] = str(int(status))

def merge_dicts(*dict_args):
"""Merge arg_params and aux_params to populate shared_buffer"""
result = {}
for dictionary in dict_args:
result.update(dictionary)
return result

def get_fp16_infer_for_fp16_graph():
return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0))

def set_fp16_infer_for_fp16_graph(status=False):
os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status))
4 changes: 2 additions & 2 deletions tests/python/tensorrt/test_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import mxnet as mx
from test_tensorrt_lenet5 import *
from common import *

def detect_cycle_from(sym, visited, stack):
visited.add(sym.handle.value)
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_simple_cycle():
set_use_tensorrt(True)
executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,),
shared_buffer=arg_params, grad_req='null', force_rebind=True)
assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contain a cycle"
assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle"

if __name__ == '__main__':
test_simple_cycle()
Expand Down
24 changes: 4 additions & 20 deletions tests/python/tensorrt/test_tensorrt_lenet5.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,7 @@
# pylint: enable=unused-import
import numpy as np
import mxnet as mx
from ctypes.util import find_library

assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library"

def get_use_tensorrt():
return int(os.environ.get("MXNET_USE_TENSORRT", 0))


def set_use_tensorrt(status=False):
os.environ["MXNET_USE_TENSORRT"] = str(int(status))


def merge_dicts(*dict_args):
"""Merge arg_params and aux_params to populate shared_buffer"""
result = {}
for dictionary in dict_args:
result.update(dictionary)
return result
from common import *


def get_iters(mnist, batch_size):
Expand Down Expand Up @@ -137,8 +120,8 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz


def test_tensorrt_inference():
"""Run inference comparison between MXNet and TensorRT.
This could be used stand-alone or with nosetests."""
"""Run LeNet-5 inference comparison between MXNet and TensorRT."""
check_tensorrt_installation()
mnist = mx.test_utils.get_mnist()
num_epochs = 10
batch_size = 1024
Expand All @@ -156,6 +139,7 @@ def test_tensorrt_inference():
# Load serialized MXNet model (model-symbol.json + model-epoch.params)
sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)

print("LeNet-5 test")
print("Running inference in MXNet")
set_use_tensorrt(False)
mx_pct = run_inference(sym, arg_params, aux_params, mnist,
Expand Down
Loading

0 comments on commit 6c7dbeb

Please sign in to comment.