diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index af8094ad92af..6a7fde62c2cf 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -68,6 +68,26 @@ bool ApplyOpInferAttr(const nnvm::Graph& g, * shape/type inference functions'. The nnvm InferAttr will be deprecated * in the future. Please use interfaces InferShape, InferType, and InferStorageType * to call this function. + * + * \param ret graph used for attribute inference + * \param emmpty_val empty value of the attribute + * \param infer_name name of the function used for attribute inference + * \param input_name name of the attribute in the graph used to store the + * input data for attribute inference + * \param attr_key_name name of the attribute used for inference for variable nodes + * \param attr_name name of the inferred attribute + * \param unknown_name name of the attribute storing number of entries + * impossible to infer + * \param fis_none function returning true for not fully inferred values + * \param fdefault default function used for inference if the node does not + * provide its own implementation. + * \param bwd_identity_assign whether the attributes of forward NDArray and backward + * NDArray have to be the same. False only for storage + * type inference + * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for + * storage type inference + * \param default_mode_val default value of the dispatch mode attribute on the node. Used + * for storage type inference */ template nnvm::Graph InferAttr(nnvm::Graph &&ret, @@ -322,23 +342,49 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, return ret; } -template +/*!\brief + * This is a version of the InferAttr function specifically for shape inference. + * + * \param ret graph used for attribute inference + * \param emmpty_val empty value of the attribute + * \param infer_name name of the function used for attribute inference + * \param input_name name of the attribute in the graph used to store the + * input data for attribute inference + * \param attr_key_name name of the attribute used for inference for variable nodes + * \param attr_name name of the inferred attribute + * \param unknown_name name of the attribute storing number of entries + * impossible to infer + * \param fis_none function returning true for not fully inferred values + * \param fnum_unknown function returning how many elements are unknown in + * partially inferred value of the attribute + * \param fdefault default function used for inference if the node does not + * provide its own implementation. + * \param bwd_identity_assign whether the attributes of forward NDArray and backward + * NDArray have to be the same. False only for storage + * type inference + * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for + * storage type inference + * \param default_mode_val default value of the dispatch mode attribute on the node. Used + * for storage type inference + */ +template nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, - const nnvm::TShape empty_val, + const mxnet::TShape empty_val, const char* infer_name, const char* input_name, const char* attr_key_name, const char* attr_name, const char* unknown_name, IsNone fis_none, + FNumUnknown fnum_unknown, FDefault fdefault, bool bwd_identity_assign, const char* dispatch_mode_name, const DispatchMode default_mode_val = DispatchMode::kUndefined) { using nnvm::IndexedGraph; using nnvm::Op; - using AttrType = nnvm::TShape; - using FInferType = nnvm::FInferShape; + using AttrType = mxnet::TShape; + using FInferType = mxnet::FInferShape; using AttrVector = std::vector; using NodeAttrVector = std::vector; using dmlc::any; @@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, }; size_t last_num_unknown; - size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0; - size_t num_unknown_entry_attr = entry_end - entry_start; - size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode; + size_t num_unknown = static_cast(-1); // Infinity + int i = 0; do { if (i % 2 == 0) { + // forward inference for (uint32_t nid = node_start; nid < node_end; ++nid) { infer_step(nid, false); } @@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, num_unknown = 0; for (size_t j = entry_start; j < entry_end; ++j) { if (fis_none(rshape[j])) { - ++num_unknown; + num_unknown += fnum_unknown(rshape[j]); } } if (dispatch_mode_name) { @@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph, if (shape_attr_key.length() != 0) { graph.attrs["shape_attr_key"] = std::make_shared(shape_attr_key); } - return InferAttr( + return InferShapeAttr( std::move(graph), mxnet::TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", "shape_num_unknown_nodes", [](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + [](const mxnet::TShape& s) { + if (s.ndim() == 0) { // TODO(reminisce): Usage of ndim + return static_cast(1); + } + size_t ret = 0; + for (const auto& val : s) { + if (val == 0) { + ++ret; + } + } + return ret; + }, nullptr, true, nullptr); } diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index ac4564b66fa0..b290ff344227 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -157,6 +157,19 @@ def test_symbol_infer_shape(): assert arg_shapes['x2h_weight'] == (num_hidden, num_dim) assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden) + # Partial shape inference with some unknown dimensions + data_shape = (1, 0, 0, 0) + data = mx.sym.Variable('data', shape=data_shape) + weight = mx.sym.Variable('weight') + cdata = mx.sym.cast(data, dtype='float16') + cweight = mx.sym.cast(weight, dtype='float16') + test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7)) + + arg, _, _ = test.infer_shape_partial() + arg_shapes = dict(zip(test.list_arguments(), arg)) + assert arg_shapes['data'] == data_shape + assert arg_shapes['weight'] == (64, 0, 7, 7) + def test_symbol_infer_shape_var(): "Test specifying shape information when constructing a variable"