Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Truncate wiki text to match target shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ulmasov committed Sep 3, 2020
1 parent 2ad74f5 commit e0219f8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/gluon/contrib/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ def _get_data(self):

data, label = self._read_batch(path)

self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len))
self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
# https://github.com/apache/incubator-mxnet/issues/18886 breaks this unless array size is
# multiple of self._seq_len. Truncating the source is consistent with pre #18886 outcome
seq_len_mult = len(data) // self._seq_len * self._seq_len
self._data = nd.array(data, dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
self._label = nd.array(label, dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len))

def __getitem__(self, idx):
return self._data[idx], self._label[idx]
Expand Down

0 comments on commit e0219f8

Please sign in to comment.