From 474d348ba644f8d4521a85f84f50bb9559b609f8 Mon Sep 17 00:00:00 2001 From: r3stl355 Date: Thu, 3 Sep 2020 09:06:54 +0100 Subject: [PATCH] Truncate wiki text to match target shape --- python/mxnet/gluon/contrib/data/text.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py index 916b41880d45..a4e3779eee68 100644 --- a/python/mxnet/gluon/contrib/data/text.py +++ b/python/mxnet/gluon/contrib/data/text.py @@ -91,8 +91,11 @@ def _get_data(self): data, label = self._read_batch(path) - self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len)) - self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len)) + # https://github.com/apache/incubator-mxnet/issues/18886 breaks this unless array size is + # multiple of self._seq_len. Truncating the source is consistent with pre #18886 outcome + seq_len_mult = len(data) // self._seq_len * self._seq_len + self._data = nd.array(data, dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len)) + self._label = nd.array(label, dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len)) def __getitem__(self, idx): return self._data[idx], self._label[idx]