From 11b994c79f96954e876a614cb393e527f35af1ec Mon Sep 17 00:00:00 2001 From: haijieg Date: Tue, 5 Jan 2016 11:51:41 -0800 Subject: [PATCH] fix a few unittests --- tests/python/gpu/test_conv.py | 2 +- tests/python/gpu/test_operator_gpu.py | 9 ++++++--- tests/python/train/test_conv.py | 2 -- tests/python/train/test_mlp.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/gpu/test_conv.py b/tests/python/gpu/test_conv.py index bbb9f386042b..9f2d4463e565 100644 --- a/tests/python/gpu/test_conv.py +++ b/tests/python/gpu/test_conv.py @@ -45,7 +45,7 @@ def get_iter(data_dir): logging.basicConfig(level=logging.DEBUG) num_gpus = 1 -data_dir = 's3://dmcl/mnist' +data_dir = 'data' (train, val) = get_iter(data_dir) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f8f43e3d52dc..6f6516068a09 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1,7 +1,10 @@ +import os import sys -sys.path.insert(0, '../unittest') +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '..', 'unittest')) +print sys.path from test_operator import * if __name__ == '__main__': - test_softmax_with_shape((3,4), mx.gpu()) - test_multi_softmax_with_shape((3,4,5), mx.gpu()) \ No newline at end of file + test_softmax_with_shape((3,4), mx.gpu()) + test_multi_softmax_with_shape((3,4,5), mx.gpu()) diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py index 5c0f11481316..ecaf5c4b0e92 100644 --- a/tests/python/train/test_conv.py +++ b/tests/python/train/test_conv.py @@ -1,6 +1,4 @@ # pylint: skip-file -import sys -sys.path.insert(0, '../../python') import mxnet as mx import numpy as np import os, pickle, gzip diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index c983b6eeac4f..0866f3623f44 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -1,7 +1,7 @@ # pylint: skip-file import mxnet as mx import numpy as np -import os, sys +import os import pickle as pickle import logging from common import get_data