Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Sequence last fix #16156

Merged
merged 10 commits into from
Sep 16, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SequenceLastOp : public Operator {
using namespace mshadow::expr;

int axis = param_.axis;
int out_size = out.size(0) * out.size(1);
int max_seq_len = data.size(axis);
index_t out_size = out.size(0) * out.size(1);
index_t max_seq_len = data.size(axis);
index_t offset1 = axis ? out.size(1) : out_size;
index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);

Expand All @@ -121,11 +121,11 @@ class SequenceLastOp : public Operator {
using namespace mshadow::expr;

auto axis = param_.axis;
int batch = out_grad.size(0);
int rest = out_grad.size(1);
int out_size = batch * rest;
index_t batch = out_grad.size(0);
index_t rest = out_grad.size(1);
index_t out_size = batch * rest;

int max_seq_len = in_grad.size(axis);
index_t max_seq_len = in_grad.size(axis);
index_t offset1 = axis ? rest : out_size;
index_t offset2 = axis ? (max_seq_len * rest) : rest;

Expand Down
11 changes: 9 additions & 2 deletions src/operator/sequence_last.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,15 @@ Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
}

// sequence_length not passed in, so fall back to using input array dtype for second argument
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
// sequence_length was not passed in, so fall back to an integer index dtype for the second argument.
// The second argument is the dtype of the sequence_length NDArray:
// int64 when built with MXNET_USE_INT64_TENSOR_SIZE == 1, int32 otherwise.
#if MXNET_USE_INT64_TENSOR_SIZE == 1
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt64);
#else
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt32);
#endif

}

DMLC_REGISTER_PARAMETER(SequenceLastParam);
Expand Down
4 changes: 3 additions & 1 deletion tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def test_sequence_last():
# test with sequence length
# parameter sequence_length - NDArray with shape (batch_size)
# (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
# sequence_length must be created with dtype="int64" to support large indices;
# otherwise the NDArray defaults to float32 and the call errors
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
use_sequence_length=True)
# check if it takes 2nd sequence from the first batch
assert b[0] == a[1][0]
Expand Down