Skip to content

Commit

Permalink
Fix DataBatch.__str__ for cases where we don't have labels. (apache#9645
Browse files Browse the repository at this point in the history
)
  • Loading branch information
larroy authored and szha committed Feb 3, 2018
1 parent e150080 commit 793804d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ def __init__(self, data, label=None, pad=None, index=None,

def __str__(self):
data_shapes = [d.shape for d in self.data]
label_shapes = [l.shape for l in self.label]
if self.label:
label_shapes = [l.shape for l in self.label]
else:
label_shapes = None
return "{}: data shapes: {} label shapes: {}".format(
self.__class__.__name__,
data_shapes,
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ def check_libSVMIter_news_data():
check_libSVMIter_synthetic()
check_libSVMIter_news_data()


def test_DataBatch():
from nose.tools import ok_
from mxnet.io import DataBatch
import re
batch = DataBatch(data=[mx.nd.ones((2,3))])
ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))])
ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))


@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/7826")
def test_CSVIter():
def check_CSVIter_synthetic():
Expand Down

0 comments on commit 793804d

Please sign in to comment.