Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add printing error messages for shape/type inference failure
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 2, 2017
1 parent 5949476 commit 66be337
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 49 deletions.
68 changes: 37 additions & 31 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,37 +1336,43 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
num_aux_states = ctypes.c_uint()
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

check_call(_LIB.MXExecutorSimpleBind(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_array(ctypes.c_char_p, provided_arg_shape_names),
c_array(mx_uint, provided_arg_shape_data),
c_array(mx_uint, provided_arg_shape_idx),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
try:
check_call(_LIB.MXExecutorSimpleBind(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_array(ctypes.c_char_p, provided_arg_shape_names),
c_array(mx_uint, provided_arg_shape_data),
c_array(mx_uint, provided_arg_shape_idx),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
except MXNetError:
print("simple_bind error. Arguments:")
for k, v in kwargs.items():
print(" %s: %s" % (k, v))
raise RuntimeError('simple_bind failed')

# update shared_buffer
if shared_buffer is not None:
Expand Down
87 changes: 69 additions & 18 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,53 @@ Graph AssignContext(Graph g,
return g;
}

void HandleInferShapeError(const size_t num_forward_inputs,
const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes) {
int cnt = 10;
std::ostringstream oss;
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const uint32_t eid = idx.entry_id(nid, 0);
const TShape& inferred_shape = inferred_shapes[eid];
if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) {
const std::string& arg_name = idx[nid].source->attrs.name;
oss << arg_name << ": " << inferred_shape << ", ";
if (--cnt == 0) {
oss << "...";
break;
}
}
}
LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
"(0s in shapes mean unknown dimension size). Please consider "
"providing them as inputs:\n"
<< oss.str();
}

void HandleInferTypeError(const size_t num_forward_inputs,
const nnvm::IndexedGraph& idx,
const nnvm::DTypeVector& inferred_dtypes) {
int cnt = 10;
std::ostringstream oss;
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const uint32_t eid = idx.entry_id(nid, 0);
const int inferred_dtype = inferred_dtypes[eid];
if (inferred_dtype == -1) {
const std::string& arg_name = idx[nid].source->attrs.name;
oss << arg_name << ": " << inferred_dtype << ", ";
if (--cnt == 0) {
oss << "...";
break;
}
}
}
LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
"(-1 means unknown dtype). Please consider providing them as inputs:\n"
<< oss.str();
}

/*!
* \brief GraphExecutor initializer for regular bind flow in which
* input arguments and gradients are provided by users. This initializer
Expand Down Expand Up @@ -391,7 +438,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,

// create arg_shapes and arg_dtypes for shape and type inferences
const auto& idx = g.indexed_graph();
auto mutable_nodes = idx.mutable_input_nodes();
const auto& mutable_nodes = idx.mutable_input_nodes();
size_t arg_top = 0, aux_top = 0;
data_entry_.resize(idx.num_node_entries());
nnvm::ShapeVector arg_shapes;
Expand Down Expand Up @@ -422,16 +469,18 @@ void GraphExecutor::Init(nnvm::Symbol symbol,

// expand arg_shapes and arg_dtypes to contain backward inputs
arg_shapes.resize(idx.input_nodes().size(), TShape());
arg_dtypes.resize(idx.input_nodes().size(), -1);
// Infer shapes and dtypes
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0)
<< "Shape inference failed in bind. Please provide"
" sufficient shapes to make inference for the symbol";
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::ShapeVector>("shape"));
}

arg_dtypes.resize(idx.input_nodes().size(), -1);
g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__");
CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0)
<< "Type inference failed in bind. Please provide"
" sufficcient types to make inference for the symbol";
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::DTypeVector>("dtype"));
}

// Initialize the rest attributes of the graph.
// This function can be called by regular bind
Expand Down Expand Up @@ -459,8 +508,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
// populate grad_store_
data_entry_.resize(idx.num_node_entries());
size_t arg_top = 0, aux_top = 0;
auto mutable_nodes = idx.mutable_input_nodes();
// TODO(junwu): populate in_arg_map, arg_grad_map, and aux_state_map
const auto& mutable_nodes = idx.mutable_input_nodes();
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const uint32_t eid = idx.entry_id(nid, 0);
Expand Down Expand Up @@ -545,7 +593,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
// initialize in_args, arg_grads, and aux_states and populate grad_store_
data_entry_.resize(idx.num_node_entries());
size_t arg_top = 0, aux_top = 0;
auto mutable_nodes = idx.mutable_input_nodes();
const auto& mutable_nodes = idx.mutable_input_nodes();
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const uint32_t eid = idx.entry_id(nid, 0);
Expand Down Expand Up @@ -744,13 +792,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
}
}
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0)
<< "Shape inference failed in simple_bind. Please provide"
" sufficient shapes to make inference for the symbol";
if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::ShapeVector>("shape"));
}

g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__");
CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0)
<< "Type inference failed in simple_bind. Please provide"
" sufficcient types to make inference for the symbol";
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::DTypeVector>("dtype"));
}

// Create in_args, arg_grads, and aux_states using
// the inferred shapes and dtypes.
Expand Down

0 comments on commit 66be337

Please sign in to comment.