Skip to content

Commit

Permalink
Sequence last fix (apache#16156)
Browse files Browse the repository at this point in the history
* seq last fix

* index tensor to have int64

* fix dtypes

* revert unnecessary changes

* if seq len not passed, pass int64 dtype

* dtype comment

* use int32 or int64 as index dtype based on build flag

* Trigger notification

* Trigger notification

* lint fix
  • Loading branch information
ChaiBapchya authored and larroy committed Sep 28, 2019
1 parent 53ae1ed commit c949cb8
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
12 changes: 6 additions & 6 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SequenceLastOp : public Operator {
using namespace mshadow::expr;

int axis = param_.axis;
int out_size = out.size(0) * out.size(1);
int max_seq_len = data.size(axis);
index_t out_size = out.size(0) * out.size(1);
index_t max_seq_len = data.size(axis);
index_t offset1 = axis ? out.size(1) : out_size;
index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);

Expand All @@ -121,11 +121,11 @@ class SequenceLastOp : public Operator {
using namespace mshadow::expr;

auto axis = param_.axis;
int batch = out_grad.size(0);
int rest = out_grad.size(1);
int out_size = batch * rest;
index_t batch = out_grad.size(0);
index_t rest = out_grad.size(1);
index_t out_size = batch * rest;

int max_seq_len = in_grad.size(axis);
index_t max_seq_len = in_grad.size(axis);
index_t offset1 = axis ? rest : out_size;
index_t offset2 = axis ? (max_seq_len * rest) : rest;

Expand Down
10 changes: 8 additions & 2 deletions src/operator/sequence_last.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
}

// sequence_length not passed in, so fall back to using input array dtype for second argument
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
// sequence_length not passed in, so fall back to using int32/int64 dtype for second argument
// second argument is the dtype of the sequence_length NDArray
// use int32 or int64 as index dtype based on build flag
#if MXNET_USE_INT64_TENSOR_SIZE == 1
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt64);
#else
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt32);
#endif
}

DMLC_REGISTER_PARAMETER(SequenceLastParam);
Expand Down
4 changes: 3 additions & 1 deletion tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def test_sequence_last():
# test with sequence length
# parameter sequence_length - NDArray with shape (batch_size)
# (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
# need to mention dtype = int64 for sequence_length ndarray to support large indices
# else it defaults to float32 and errors
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
use_sequence_length=True)
# check if it takes 2nd sequence from the first batch
assert b[0] == a[1][0]
Expand Down

0 comments on commit c949cb8

Please sign in to comment.