From 5ac2371bf39d0532d8425e40dec9bae577016857 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Thu, 6 Sep 2018 11:36:59 -0700 Subject: [PATCH 1/2] allow foreach on input with 0 length --- src/operator/control_flow.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index d6b6703ddd58..ba7f5c0ad8b2 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -314,7 +314,6 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs, // For the shape of output data. size_t len = in_shape->at(0)[0]; - CHECK_GT(len, 0); for (int i = 0; i < params.num_out_data; i++) { // If the output shape isn't inferred, we don't need to propogate the info. const auto& g_out_shape = subg_out_shape[i]; From 5c0829bc8d66bec02d688b6272b5ee026efb1db2 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Fri, 7 Sep 2018 11:58:27 -0700 Subject: [PATCH 2/2] add test foreach with unknown dim --- tests/python/unittest/test_contrib_control_flow.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 1c23c9161977..dd5a4d6d3152 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -2146,6 +2146,15 @@ def func3(data): for i in range(len(out1)): assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001) +def test_foreach_with_unkown_dim(): + # MXNet supports using 0 as placeholder for unknown dimensions in shape + step = lambda data, states: (data + states[0], [states[0] * 2]) + # input shape with NCHW format and N is unknown + data = mx.sym.var('data', shape=(0, 3, 32, 32)) + states = [mx.sym.var('state')] + outs, states = mx.sym.contrib.foreach(step, data, states) + _, output_shape, _ = outs.infer_shape_partial() + assert_allclose((0, 3, 32, 32), output_shape[0]) if __name__ == '__main__': import nose