diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py index 916b41880d45..9cfe5349d86b 100644 --- a/python/mxnet/gluon/contrib/data/text.py +++ b/python/mxnet/gluon/contrib/data/text.py @@ -91,8 +91,11 @@ def _get_data(self): data, label = self._read_batch(path) - self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len)) - self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len)) + # https://github.com/apache/incubator-mxnet/issues/18886 breaks this unless array size is + # multiple of self._seq_len. Truncating the source is consistent with pre #18886 outcome + seq_len_mult = len(data) // self._seq_len * self._seq_len + self._data = nd.array(data, dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len)) + self._label = nd.array(label, dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len)) def __getitem__(self, idx): return self._data[idx], self._label[idx]