diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index 3c3c8b0cd49e..78ade5e9de06 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -101,8 +101,8 @@ class SequenceLastOp : public Operator { using namespace mshadow::expr; int axis = param_.axis; - int out_size = out.size(0) * out.size(1); - int max_seq_len = data.size(axis); + index_t out_size = out.size(0) * out.size(1); + index_t max_seq_len = data.size(axis); index_t offset1 = axis ? out.size(1) : out_size; index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1); @@ -121,11 +121,11 @@ class SequenceLastOp : public Operator { using namespace mshadow::expr; auto axis = param_.axis; - int batch = out_grad.size(0); - int rest = out_grad.size(1); - int out_size = batch * rest; + index_t batch = out_grad.size(0); + index_t rest = out_grad.size(1); + index_t out_size = batch * rest; - int max_seq_len = in_grad.size(axis); + index_t max_seq_len = in_grad.size(axis); index_t offset1 = axis ? rest : out_size; index_t offset2 = axis ? (max_seq_len * rest) : rest; diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc index 44869c518504..3a6cdbad6149 100644 --- a/src/operator/sequence_last.cc +++ b/src/operator/sequence_last.cc @@ -46,8 +46,14 @@ Operator *SequenceLastProp::CreateOperatorEx(Context ctx, DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]); } - // sequence_length not passed in, so fall back to using input array dtype for second argument - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]); + // sequence_length not passed in, so fall back to using int32/int64 dtype for second argument + // second argument is the dtype of the sequence_length NDArray + // use int32 or int64 as index dtype based on build flag + #if MXNET_USE_INT64_TENSOR_SIZE == 1 + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt64); + #else + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt32); + #endif } DMLC_REGISTER_PARAMETER(SequenceLastParam); diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index e3407792f2da..169f5244d784 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -356,7 +356,9 @@ def test_sequence_last(): # test with sequence length # parameter sequence_length - NDArray with shape (batch_size) # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2 - b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]), + # need to mention dtype = int64 for sequence_length ndarray to support large indices + # else it defaults to float32 and errors + b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"), use_sequence_length=True) # check if it takes 2nd sequence from the first batch assert b[0] == a[1][0]