
Fix shape inference pass #14153

Merged
merged 5 commits on Mar 5, 2019
76 changes: 67 additions & 9 deletions src/executor/infer_graph_attr_pass.cc
@@ -68,6 +68,26 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
* shape/type inference functions'. The nnvm InferAttr will be deprecated
* in the future. Please use interfaces InferShape, InferType, and InferStorageType
* to call this function.
*
* \param ret graph used for attribute inference
* \param empty_val empty value of the attribute
* \param infer_name name of the function used for attribute inference
* \param input_name name of the attribute in the graph used to store the
* input data for attribute inference
* \param attr_key_name name of the attribute used for inference for variable nodes
* \param attr_name name of the inferred attribute
* \param unknown_name name of the attribute storing the number of entries
* impossible to infer
* \param fis_none function returning true for not fully inferred values
* \param fdefault default function used for inference if the node does not
* provide its own implementation.
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
* NDArray have to be the same. False only for storage
* type inference
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
* storage type inference
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
* for storage type inference
*/
template<typename AttrType, typename FInferType, typename IsNone, typename FDefault>
nnvm::Graph InferAttr(nnvm::Graph &&ret,
@@ -322,23 +342,49 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
return ret;
}

template<typename IsNone, typename FDefault>
/*!\brief
* This is a version of the InferAttr function specifically for shape inference.
*
* \param ret graph used for attribute inference
* \param empty_val empty value of the attribute
* \param infer_name name of the function used for attribute inference
* \param input_name name of the attribute in the graph used to store the
* input data for attribute inference
* \param attr_key_name name of the attribute used for inference for variable nodes
* \param attr_name name of the inferred attribute
* \param unknown_name name of the attribute storing the number of entries
* impossible to infer
* \param fis_none function returning true for not fully inferred values
* \param fnum_unknown function returning how many elements are unknown in
* a partially inferred value of the attribute
* \param fdefault default function used for inference if the node does not
* provide its own implementation.
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
* NDArray have to be the same. False only for storage
* type inference
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
* storage type inference
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
* for storage type inference
*/
template<typename IsNone, typename FDefault, typename FNumUnknown>
nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
const nnvm::TShape empty_val,
const mxnet::TShape empty_val,
const char* infer_name,
const char* input_name,
const char* attr_key_name,
const char* attr_name,
const char* unknown_name,
IsNone fis_none,
FNumUnknown fnum_unknown,
FDefault fdefault,
bool bwd_identity_assign,
const char* dispatch_mode_name,
const DispatchMode default_mode_val = DispatchMode::kUndefined) {
using nnvm::IndexedGraph;
using nnvm::Op;
using AttrType = nnvm::TShape;
using FInferType = nnvm::FInferShape;
using AttrType = mxnet::TShape;
using FInferType = mxnet::FInferShape;
using AttrVector = std::vector<AttrType>;
using NodeAttrVector = std::vector<DispatchMode>;
using dmlc::any;
@@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
};

size_t last_num_unknown;
size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
size_t num_unknown_entry_attr = entry_end - entry_start;
size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
size_t num_unknown = static_cast<size_t>(-1); // Infinity

int i = 0;
do {
if (i % 2 == 0) {
// forward inference
for (uint32_t nid = node_start; nid < node_end; ++nid) {
infer_step(nid, false);
}
@@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
num_unknown = 0;
for (size_t j = entry_start; j < entry_end; ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
num_unknown += fnum_unknown(rshape[j]);
}
}
if (dispatch_mode_name) {
@@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
}
return InferAttr<mxnet::TShape, mxnet::FInferShape>(
return InferShapeAttr(
std::move(graph), mxnet::TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
[](const mxnet::TShape& s) {
if (s.ndim() == 0) { // TODO(reminisce): Usage of ndim
return static_cast<size_t>(1);
}
size_t ret = 0;
for (const auto& val : s) {
if (val == 0) {
++ret;
}
}
return ret;
},
nullptr, true, nullptr);
}

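A note on the substance of the change above: the original InferShape pass counted every not-fully-known shape as a single unknown entry (the old ++num_unknown), while the new InferShapeAttr pass asks the fnum_unknown callback how many dimensions are still unknown, counting a shape with ndim() == 0 as one entry and otherwise each 0-valued dimension individually. The forward/backward fixed-point loop then iterates until this finer-grained count stops shrinking. Below is a minimal Python sketch of that counting rule; the helper name and the plain-tuple stand-in for mxnet::TShape are illustrative only, not part of the PR.

def num_unknown_dims(shape):
    """Mirror of the fnum_unknown lambda: an empty tuple (ndim == 0) counts
    as one unknown entry; otherwise each 0-valued dimension counts."""
    if len(shape) == 0:
        return 1
    return sum(1 for dim in shape if dim == 0)

print(num_unknown_dims(()))              # 1: nothing known yet
print(num_unknown_dims((1, 0, 224, 0)))  # 2: two dimensions still unknown
print(num_unknown_dims((64, 3, 7, 7)))   # 0: fully inferred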
13 changes: 13 additions & 0 deletions tests/python/unittest/test_symbol.py
@@ -157,6 +157,19 @@ def test_symbol_infer_shape():
assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden)

# Partial shape inference with some unknown dimensions
data_shape = (1, 0, 0, 0)
data = mx.sym.Variable('data', shape=data_shape)
weight = mx.sym.Variable('weight')
cdata = mx.sym.cast(data, dtype='float16')
cweight = mx.sym.cast(weight, dtype='float16')
test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7))

arg, _, _ = test.infer_shape_partial()
arg_shapes = dict(zip(test.list_arguments(), arg))
assert arg_shapes['data'] == data_shape
assert arg_shapes['weight'] == (64, 0, 7, 7)


def test_symbol_infer_shape_var():
"Test specifying shape information when constructing a variable"
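For readers who want to exercise the behavior the new test covers, here is a minimal usage sketch; it assumes an MXNet build that includes this fix and omits the float16 casts used in the test. With only the batch dimension declared, infer_shape_partial leaves unknown dimensions as 0 rather than failing, and the weight shape resolves to (64, 0, 7, 7) because num_filter and kernel are known while the input channel count is not.

import mxnet as mx

# Only the batch dimension of the input is known; 0 marks an unknown dimension.
data = mx.sym.Variable('data', shape=(1, 0, 0, 0))
weight = mx.sym.Variable('weight')
conv = mx.sym.Convolution(data=data, weight=weight, kernel=(7, 7), stride=(2, 2),
                          pad=(3, 3), num_filter=64, no_bias=True)

# Partial inference fills in whatever follows from the declared shapes and the
# operator's hyper-parameters, and leaves the rest as 0.
arg_shapes, out_shapes, aux_shapes = conv.infer_shape_partial()
for name, shape in zip(conv.list_arguments(), arg_shapes):
    print(name, shape)
# data (1, 0, 0, 0)
# weight (64, 0, 7, 7)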