Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] fix #18938, #18939, #18940 (#19833)
Browse files Browse the repository at this point in the history
* fix #18938

* fix #18939, #18940

Co-authored-by: r3stl355 <[email protected]>
  • Loading branch information
r3stl355 and ulmasov authored Feb 7, 2021
1 parent c459127 commit 3dd678f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class SequenceLastOp : public Operator {
auto d1 = in_data[seq_last::kData].size(1);
auto dsize = in_data[seq_last::kData].Size();

if (dsize == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto batch = (axis != 0) ? d0 : d1;
auto max_seq_len = in_data[seq_last::kData].size(axis);
auto rest_size = dsize / (d0 * d1);
Expand Down
5 changes: 5 additions & 0 deletions src/operator/sequence_mask-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ class SequenceMaskOp : public Operator {
auto d0 = in_data[seq_mask::kData].size(0);
auto d1 = in_data[seq_mask::kData].size(1);
auto dsize = in_data[seq_mask::kData].Size();

if (dsize == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto rest_size = dsize / (d0 * d1);

Shape<3> s3 = Shape3(d0, d1, rest_size);
Expand Down
5 changes: 5 additions & 0 deletions src/operator/sequence_reverse-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class SequenceReverseOp : public Operator {
auto max_seq_len = in_data[seq_reverse::kData].size(0);
auto n = in_data[seq_reverse::kData].size(1);
auto total_size = in_data[seq_reverse::kData].Size();

if (total_size == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto rest_dim = static_cast<int>(total_size / n / max_seq_len);

Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9428,3 +9428,28 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d)
test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d)

def test_zero_sized_dim():

mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown'

def seq_last():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""
data = mx.nd.array(np.random.rand(1, 0, 0))
res = mx.nd.op.SequenceLast(data)
assert data.shape[1:] == res.shape

def seq_mask():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18939"""
data = mx.nd.array(np.random.rand(0, 1, 1))
res = mx.nd.op.SequenceMask(data)
assert data.shape == res.shape

def seq_reverse():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18940"""
data = mx.nd.array(np.random.rand(0, 1, 1))
res = mx.nd.op.SequenceReverse(data)
assert data.shape == res.shape

seq_last()
seq_reverse()
seq_mask()

0 comments on commit 3dd678f

Please sign in to comment.