From e8883e494a3641f55bda8a8f7ea6aefaa5f39e2f Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 12 Sep 2019 10:36:45 -0700 Subject: [PATCH] fix dtypes --- src/operator/sequence_last-inl.h | 8 ++++---- src/operator/sequence_last.cc | 6 +++--- tests/nightly/test_large_vector.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index b77782e711a6..78ade5e9de06 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -121,11 +121,11 @@ class SequenceLastOp : public Operator { using namespace mshadow::expr; auto axis = param_.axis; - int batch = out_grad.size(0); - int rest = out_grad.size(1); - int out_size = batch * rest; + index_t batch = out_grad.size(0); + index_t rest = out_grad.size(1); + index_t out_size = batch * rest; - int max_seq_len = in_grad.size(axis); + index_t max_seq_len = in_grad.size(axis); index_t offset1 = axis ? rest : out_size; index_t offset2 = axis ? (max_seq_len * rest) : rest; diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc index b53c9a934fd6..69886d5b891f 100644 --- a/src/operator/sequence_last.cc +++ b/src/operator/sequence_last.cc @@ -31,9 +31,9 @@ template <> Operator *CreateOp(SequenceLastParam param, int dtype, int itype) { Operator *op = nullptr; MSHADOW_TYPE_SWITCH(dtype, DType, { -// MSHADOW_TYPE_SWITCH(itype, IType, { - op = new SequenceLastOp(param); - // }); + MSHADOW_TYPE_SWITCH(itype, IType, { + op = new SequenceLastOp(param); + }); }); return op; } diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index b59a4462d922..65dd71153c5a 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -346,7 +346,7 @@ def test_sequence_reverse(): def test_sequence_last(): - a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2) + a = nd.arange(0, LARGE_X * 2, dtype="int64").reshape(LARGE_X, 2) # test if returns last sequence b = nd.SequenceLast(a) @@ -356,7 +356,7 @@ def test_sequence_last(): # test with sequence length # parameter sequence_length - NDArray with shape (batch_size) # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2 - b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]), + b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"), use_sequence_length=True) # check if it takes 2nd sequence from the first batch assert b[0] == a[1][0]