Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Sep 12, 2019
1 parent 5b6d697 commit e8883e4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ class SequenceLastOp : public Operator {
using namespace mshadow::expr;

auto axis = param_.axis;
int batch = out_grad.size(0);
int rest = out_grad.size(1);
int out_size = batch * rest;
index_t batch = out_grad.size(0);
index_t rest = out_grad.size(1);
index_t out_size = batch * rest;

int max_seq_len = in_grad.size(axis);
index_t max_seq_len = in_grad.size(axis);
index_t offset1 = axis ? rest : out_size;
index_t offset2 = axis ? (max_seq_len * rest) : rest;

Expand Down
6 changes: 3 additions & 3 deletions src/operator/sequence_last.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ template <>
Operator *CreateOp<cpu>(SequenceLastParam param, int dtype, int itype) {
Operator *op = nullptr;
MSHADOW_TYPE_SWITCH(dtype, DType, {
// MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceLastOp<cpu, DType, int64_t>(param);
// });
MSHADOW_TYPE_SWITCH(itype, IType, {
op = new SequenceLastOp<cpu, DType, IType>(param);
});
});
return op;
}
Expand Down
4 changes: 2 additions & 2 deletions tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_sequence_reverse():


def test_sequence_last():
a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2)
a = nd.arange(0, LARGE_X * 2, dtype="int64").reshape(LARGE_X, 2)

# test if returns last sequence
b = nd.SequenceLast(a)
Expand All @@ -356,7 +356,7 @@ def test_sequence_last():
# test with sequence length
# parameter sequence_length - NDArray with shape (batch_size)
# (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
use_sequence_length=True)
# check if it takes 2nd sequence from the first batch
assert b[0] == a[1][0]
Expand Down

0 comments on commit e8883e4

Please sign in to comment.