Skip to content

Commit

Permalink
Merge pull request apache#18 from dato-code/sframe_iter_bug_fix
Browse files Browse the repository at this point in the history
fix sframe iter padding
  • Loading branch information
Jay Gu committed Apr 7, 2016
2 parents a6ef6af + 8d05c0b commit b0160ba
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 3 additions & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ def _infer_column_shape(self, sarray):
return (lengths.max(), )
elif dtype is gl.Image:
first_image = sarray.head(1)[0]
if first_image is None:
raise ValueError('Column cannot contain missing value')
return (first_image.channels, first_image.height, first_image.width)

def infer_shape(self):
Expand Down Expand Up @@ -514,7 +516,7 @@ def iter_next(self):
end = start + self.batch_size
if end >= self.data_size:
self.has_next = False
self.pad = self.data_size - end
self.pad = end - self.data_size
end = self.data_size
self._copy(start, end)
if self.pad > 0:
Expand Down
10 changes: 7 additions & 3 deletions tests/python/unittest/test_sframe_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,20 @@ def test_non_divisible_batch(self):
def test_padding(self):
padding = 5
batch_size = self.data_size + padding
shape_total = 1
for s in self.shape:
shape_total *= s
it = mxnet.io.SFrameIter(self.data, data_field=self.data_field,
label_field=self.label_field,
batch_size=batch_size)
label_expected = self.label_expected + [0.0] * padding
data_expected = self.data_expected + list(np.zeros(self.shape).flatten()) * padding
label_expected = self.label_expected + self.label_expected[:padding]
data_expected = self.data_expected + self.data_expected[:(padding * shape_total)]
label_actual = []
data_actual = []
for d in it:
data_actual.extend(d.data[0].asnumpy().flatten())
label_actual.extend(d.label[0].asnumpy().flatten())
self.assertEqual(d.pad, padding)
np.testing.assert_almost_equal(label_actual, label_expected)
np.testing.assert_almost_equal(data_actual, data_expected)

Expand All @@ -88,7 +92,7 @@ def test_missing_value(self):
self.data_field = [self.data_field]
for col in self.data_field:
ls = list(data[col])
ls[0] = None
ls[1] = None
data[col] = ls
it = mxnet.io.SFrameIter(data, data_field=self.data_field)
self.assertRaises(lambda: [it])
Expand Down

0 comments on commit b0160ba

Please sign in to comment.