Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ulmasov committed Feb 3, 2021
1 parent 65beccd commit 9319597
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class SequenceLastOp : public Operator {
auto d1 = in_data[seq_last::kData].size(1);
auto dsize = in_data[seq_last::kData].Size();

if (dsize == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto batch = (axis != 0) ? d0 : d1;
auto max_seq_len = in_data[seq_last::kData].size(axis);
auto rest_size = dsize / (d0 * d1);
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9428,3 +9428,11 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d)
test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d)

def test_zero_sized_dim():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""
mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown'
data = mx.nd.array(np.random.rand(1, 0, 0))
res = mx.nd.op.SequenceLast(data)
assert data.shape[1:] == res.shape
assert len(res) == 0

0 comments on commit 9319597

Please sign in to comment.