From 1115d19c5c85deb7fc328b28fc4612d79073d742 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 1 Jun 2017 20:32:22 -0700 Subject: [PATCH] Add checks for shape/type inferences --- src/executor/graph_executor.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 64e2410bedb2..20135d33dee6 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -425,7 +425,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_dtypes.resize(idx.input_nodes().size(), -1); // Infer shapes and dtypes g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in bind. Please provide" + " sufficcient types to make inference for the symbol"; // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -738,7 +744,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0) + << "Shape inference failed in simple_bind. Please provide" + " sufficient shapes to make inference for the symbol"; g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0) + << "Type inference failed in simple_bind. Please provide" + " sufficcient types to make inference for the symbol"; // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes.