diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 413395c3b74f..53c4d73a3537 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -226,9 +226,7 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector std::stack s; s.push(descendant); size_t count = 0; - while (!s.empty()) { - CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably" - " a loop in the graph"; + while (!s.empty() && count < indexed_graph.num_nodes()) { ++count; const nnvm::Node* top = s.top(); s.pop(); @@ -276,10 +274,6 @@ bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector if (excluded_node_id != -1) { CHECK_LT(excluded_node_id, static_cast(simple_nodes.size())); - CHECK_NE(excluded_node_id, static_cast(snid)) - << "A cycle is found in the computational graph between nodes " - << simple_nodes[excluded_node_id]->node->attrs.name << " and " - << simple_nodes[snid]->node->attrs.name; excluded_nodes->insert(simple_nodes[excluded_node_id].get()); ResetNodeLabels(g, simple_nodes, subgraph_nodes); return false; @@ -306,6 +300,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph const std::vector& simple_nodes, std::vector* subgraph_nodes) { std::unordered_set excluded_nodes; + size_t n_excluded_nodes = 0; const size_t max_num_retry = simple_nodes.size() * simple_nodes.size(); size_t count = 0; bool success = false; @@ -313,7 +308,11 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, subgraph_nodes, &excluded_nodes); if (!success) { - CHECK(!excluded_nodes.empty()); + if (excluded_nodes.size() == n_excluded_nodes) { + // no possible subgraph for the current node snid + break; + } + n_excluded_nodes = excluded_nodes.size(); std::string excluded_node_names; for (auto node : excluded_nodes) { excluded_node_names += node->node->attrs.name + ", ";