Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix InferShape pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Mar 1, 2019
1 parent 992c3c0 commit 29ed75d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
70 changes: 64 additions & 6 deletions src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,26 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
* shape/type inference functions'. The nnvm InferAttr will be deprecated
* in the future. Please use interfaces InferShape, InferType, and InferStorageType
* to call this function.
*
* \param ret graph used for attribute inference
* \param emmpty_val empty value of the attribute
* \param infer_name name of the function used for attribute inference
* \param input_name name of the attribute in the graph used to store the
* input data for attribute inference
* \param attr_key_name name of the attribute used for inference for variable nodes
* \param attr_name name of the inferred attribute
* \param unknown_name name of the attribute storing number of entries
* impossible to infer
* \param fis_none function returning true for not fully inferred values
* \param fdefault default function used for inference if the node does not
* provide its own implementation.
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
* NDArray have to be the same. False only for storage
* type inference
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
* storage type inference
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
* for storage type inference
*/
template<typename AttrType, typename FInferType, typename IsNone, typename FDefault>
nnvm::Graph InferAttr(nnvm::Graph &&ret,
Expand Down Expand Up @@ -322,7 +342,32 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
return ret;
}

template<typename IsNone, typename FDefault>
/*!\brief
* This is a version of the InferAttr function specifically for shape inference.
*
* \param ret graph used for attribute inference
* \param emmpty_val empty value of the attribute
* \param infer_name name of the function used for attribute inference
* \param input_name name of the attribute in the graph used to store the
* input data for attribute inference
* \param attr_key_name name of the attribute used for inference for variable nodes
* \param attr_name name of the inferred attribute
* \param unknown_name name of the attribute storing number of entries
* impossible to infer
* \param fis_none function returning true for not fully inferred values
* \param fnum_unknown function returning how many elements are unknown in
* partially inferred value of the attribute
* \param fdefault default function used for inference if the node does not
* provide its own implementation.
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
* NDArray have to be the same. False only for storage
* type inference
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
* storage type inference
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
* for storage type inference
*/
template<typename IsNone, typename FDefault, typename FNumUnknown>
nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
const nnvm::TShape empty_val,
const char* infer_name,
Expand All @@ -331,6 +376,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
const char* attr_name,
const char* unknown_name,
IsNone fis_none,
FNumUnknown fnum_unknown,
FDefault fdefault,
bool bwd_identity_assign,
const char* dispatch_mode_name,
Expand Down Expand Up @@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
};

size_t last_num_unknown;
size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
size_t num_unknown_entry_attr = entry_end - entry_start;
size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
size_t num_unknown = static_cast<size_t>(-1); // Infinity

int i = 0;
do {
if (i % 2 == 0) {
// forward inference
for (uint32_t nid = node_start; nid < node_end; ++nid) {
infer_step(nid, false);
}
Expand All @@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
num_unknown = 0;
for (size_t j = entry_start; j < entry_end; ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
num_unknown += fnum_unknown(rshape[j]);
}
}
if (dispatch_mode_name) {
Expand Down Expand Up @@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
}
return InferAttr<mxnet::TShape, mxnet::FInferShape>(
return InferShapeAttr(
std::move(graph), mxnet::TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
[](const mxnet::TShape& s) {
if (s.ndim() == 0) { // TODO(reminisce): Usage of ndim
return static_cast<size_t>(1);
}
size_t ret = 0;
for (const auto& val : s) {
if (val == 0) {
++ret;
}
}
return ret;
},
nullptr, true, nullptr);
}

Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ def test_symbol_infer_shape():
assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden)

# Partial shape inference with some unknown dimensions
data_shape = (1, 0, 0, 0)
data = mx.sym.Variable('data', shape=data_shape)
weight = mx.sym.Variable('weight')
cdata = mx.sym.cast(data, dtype='float16')
cweight = mx.sym.cast(weight, dtype='float16')
test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7))

arg, _, _ = test.infer_shape_partial()
arg_shapes = dict(zip(test.list_arguments(), arg))
assert arg_shapes['data'] == data_shape
assert arg_shapes['weight'] == (64, 0, 7, 7)


def test_symbol_infer_shape_var():
"Test specifying shape information when constructing a variable"
Expand Down

0 comments on commit 29ed75d

Please sign in to comment.