Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
allow foreach on input with 0 length (#12471)
Browse files Browse the repository at this point in the history
* allow foreach on input with 0 length

* add test foreach with unknown dim
  • Loading branch information
roywei authored and sandeep-krishnamurthy committed Sep 8, 2018
1 parent 445967e commit 4eb7626
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/operator/control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,

// For the shape of output data.
size_t len = in_shape->at(0)[0];
CHECK_GT(len, 0);
for (int i = 0; i < params.num_out_data; i++) {
// If the output shape isn't inferred, we don't need to propogate the info.
const auto& g_out_shape = subg_out_shape[i];
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,15 @@ def func3(data):
for i in range(len(out1)):
assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001)

def test_foreach_with_unkown_dim():
# MXNet supports using 0 as placeholder for unknown dimensions in shape
step = lambda data, states: (data + states[0], [states[0] * 2])
# input shape with NCHW format and N is unknown
data = mx.sym.var('data', shape=(0, 3, 32, 32))
states = [mx.sym.var('state')]
outs, states = mx.sym.contrib.foreach(step, data, states)
_, output_shape, _ = outs.infer_shape_partial()
assert_allclose((0, 3, 32, 32), output_shape[0])

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 4eb7626

Please sign in to comment.