Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add checks for shape/type inferences
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 3, 2017
1 parent 3a43edf commit 1115d19
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
arg_dtypes.resize(idx.input_nodes().size(), -1);
// Infer shapes and dtypes
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0)
<< "Shape inference failed in bind. Please provide"
" sufficient shapes to make inference for the symbol";
g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__");
CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0)
<< "Type inference failed in bind. Please provide"
" sufficcient types to make inference for the symbol";

// Initialize the rest attributes of the graph.
// This function can be called by regular bind
Expand Down Expand Up @@ -738,7 +744,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
}
}
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0)
<< "Shape inference failed in simple_bind. Please provide"
" sufficient shapes to make inference for the symbol";
g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__");
CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0)
<< "Type inference failed in simple_bind. Please provide"
" sufficcient types to make inference for the symbol";

// Create in_args, arg_grads, and aux_states using
// the inferred shapes and dtypes.
Expand Down

0 comments on commit 1115d19

Please sign in to comment.