diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 4b9814e46e6d..0cb296bd61f9 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -30,6 +30,7 @@ #include "nnvm/pass_functions.h" #include "nnvm/symbolic.h" #include "./c_api_common.h" +#include "../common/exec_utils.h" #include "../operator/operator_common.h" #include "../executor/exec_pass.h" #include "../operator/subgraph/subgraph_property.h" @@ -1213,8 +1214,9 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, arg_dtypes.push_back(in_arg.dtype()); arg_stypes.push_back(in_arg.storage_type()); in_arg_ctxes[i] = in_arg.ctx(); - orig_g = InferForwardAttrs(orig_g, arg_shapes, arg_dtypes, arg_stypes, - default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, true); + } + orig_g = common::InferForwardAttrs(orig_g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, in_arg_ctxes, aux_state_ctxes, true); } std::vector> options_map; for (mx_uint i = 0; i < num_options; ++i) { diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index d8b7a33bf22b..8377744b4aee 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -621,6 +621,42 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, return g; } +/*! + * \brief infers shapes, dtypes, stypes, contexts for the forward graph + */ +inline nnvm::Graph InferForwardAttrs(nnvm::Graph g, + mxnet::ShapeVector arg_shapes, + nnvm::DTypeVector arg_dtypes, + StorageTypeVector arg_stypes, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& aux_state_ctxes, + bool partial_shape = false) { + const auto& indexed_graph = g.indexed_graph(); + const auto num_forward_inputs = indexed_graph.input_nodes().size(); + g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {}, + aux_state_ctxes, {}, num_forward_inputs, g.outputs.size()); + g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + if (!partial_shape) { + HandleInferShapeError(num_forward_inputs, indexed_graph, + g.GetAttr("shape")); + } + } + g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs, indexed_graph, + g.GetAttr("dtype")); + } + g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs, indexed_graph, + g.GetAttr("storage_type")); + } + return g; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_EXEC_UTILS_H_ diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index c5da0a0cf90d..0a3311f1a055 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1602,40 +1602,6 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, return ret; } -// Infer shapes, dtypes, stypes, contexts for the forward graph -static nnvm::Graph InferForwardAttrs(nnvm::Graph g, - mxnet::ShapeVector arg_shapes, - nnvm::DTypeVector arg_dtypes, - StorageTypeVector arg_stypes, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& aux_state_ctxes, - bool partial_shape = false) { - const auto& indexed_graph = g.indexed_graph(); - const auto num_forward_inputs = indexed_graph.input_nodes().size(); - g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {}, - aux_state_ctxes, {}, num_forward_inputs, g.outputs.size()); - g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - if (g.GetAttr("shape_num_unknown_nodes") != 0U) { - if (!partial_shape) { - HandleInferShapeError(num_forward_inputs, indexed_graph, - g.GetAttr("shape")); - } - } - g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); - if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { - HandleInferTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("dtype")); - } - g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); - if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { - HandleInferStorageTypeError(num_forward_inputs, indexed_graph, - g.GetAttr("storage_type")); - } - return g; -} - static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend, const Context& default_ctx, bool verbose = false) { diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index 7d87043c8196..f5c8632fd0a6 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace mxnet { namespace op {