diff --git a/python/mxnet/io.py b/python/mxnet/io.py index b07f7c1bea57..201414e8f6e1 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -168,7 +168,10 @@ def __init__(self, data, label=None, pad=None, index=None, def __str__(self): data_shapes = [d.shape for d in self.data] - label_shapes = [l.shape for l in self.label] + if self.label: + label_shapes = [l.shape for l in self.label] + else: + label_shapes = None return "{}: data shapes: {} label shapes: {}".format( self.__class__.__name__, data_shapes, diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index e8aba38b8253..58ca1d74fbba 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -252,6 +252,17 @@ def check_libSVMIter_news_data(): check_libSVMIter_synthetic() check_libSVMIter_news_data() + +def test_DataBatch(): + from nose.tools import ok_ + from mxnet.io import DataBatch + import re + batch = DataBatch(data=[mx.nd.ones((2,3))]) + ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch))) + batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))]) + ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch))) + + @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/7826") def test_CSVIter(): def check_CSVIter_synthetic():