diff --git a/python/mxnet/io.py b/python/mxnet/io.py index fdb32408d191..00f7502580a7 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -478,6 +478,8 @@ def _infer_column_shape(self, sarray): return (lengths.max(), ) elif dtype is gl.Image: first_image = sarray.head(1)[0] + if first_image is None: + raise ValueError('Column cannot contain missing value') return (first_image.channels, first_image.height, first_image.width) def infer_shape(self): @@ -514,7 +516,7 @@ def iter_next(self): end = start + self.batch_size if end >= self.data_size: self.has_next = False - self.pad = self.data_size - end + self.pad = end - self.data_size end = self.data_size self._copy(start, end) if self.pad > 0: diff --git a/tests/python/unittest/test_sframe_iter.py b/tests/python/unittest/test_sframe_iter.py index d6bbe293d8ef..a2e1988094e1 100644 --- a/tests/python/unittest/test_sframe_iter.py +++ b/tests/python/unittest/test_sframe_iter.py @@ -63,16 +63,20 @@ def test_non_divisible_batch(self): def test_padding(self): padding = 5 batch_size = self.data_size + padding + shape_total = 1 + for s in self.shape: + shape_total *= s it = mxnet.io.SFrameIter(self.data, data_field=self.data_field, label_field=self.label_field, batch_size=batch_size) - label_expected = self.label_expected + [0.0] * padding - data_expected = self.data_expected + list(np.zeros(self.shape).flatten()) * padding + label_expected = self.label_expected + self.label_expected[:padding] + data_expected = self.data_expected + self.data_expected[:(padding * shape_total)] label_actual = [] data_actual = [] for d in it: data_actual.extend(d.data[0].asnumpy().flatten()) label_actual.extend(d.label[0].asnumpy().flatten()) + self.assertEqual(d.pad, padding) np.testing.assert_almost_equal(label_actual, label_expected) np.testing.assert_almost_equal(data_actual, data_expected) @@ -88,7 +92,7 @@ def test_missing_value(self): self.data_field = [self.data_field] for col in self.data_field: ls = list(data[col]) - ls[0] = None + ls[1] = None data[col] = ls it = mxnet.io.SFrameIter(data, data_field=self.data_field) self.assertRaises(lambda: [it])