Skip to content

Commit

Permalink
Faster pointwise fusion graph pass (apache#19269)
Browse files Browse the repository at this point in the history
* Faster pointwise fusion graph pass

* Fix lint

* Fix lint 2

* Fixes

* Fixing slice parameter handling in fusion

* Fixing the slice fix

* Fix the cycle bug

* Added test

* Fix lint

* Fix merging of subgraphs

* Fixes from review
  • Loading branch information
ptrendx authored Oct 19, 2020
1 parent e6cea0d commit a0fd1fe
Show file tree
Hide file tree
Showing 7 changed files with 805 additions and 668 deletions.
5 changes: 1 addition & 4 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,7 @@ void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph*
common::CopyGraph(&unoptimized_graph, *full_graph, false);

if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
full_graph->attrs["num_forward_outputs"] = std::make_shared<nnvm::any>(num_forward_outputs);
*full_graph = exec::FusePointwiseForward(std::move(*full_graph));
full_graph->attrs["num_forward_outputs"] = std::make_shared<nnvm::any>(num_forward_outputs);
*full_graph = exec::FusePointwiseBackward(std::move(*full_graph));
*full_graph = exec::FusePointwise(*full_graph, num_forward_outputs);
// Fill in input_map - mapping from the new to the original input indices.
const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes();
const auto &new_inputs = full_graph->indexed_graph().input_nodes();
Expand Down
16 changes: 4 additions & 12 deletions src/imperative/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,14 @@ Graph DetectInplaceAddTo(Graph g);
Graph EliminateCommonExpr(Graph && g);

/*!
* \brief Fuse pointwise operations in the forward pass.
* \brief Fuse pointwise operations in the graph.
*
* \param g input graph (needs to be entire graph, not just forward part)
* \param num_forward_outputs number of outputs in the graph produced by the forward pass
*
* \return graph with fused pointwise operations in the forward pass
* \return copy of the graph with fused pointwise operations
*/
Graph FusePointwiseForward(Graph&& g);

/*!
* \brief Fuse pointwise operations in the backward pass.
*
* \param g input graph (needs to be entire graph, not just forward part)
*
* \return graph with fused pointwise operations in the backward pass
*/
Graph FusePointwiseBackward(Graph&& g);
Graph FusePointwise(const Graph& g, const size_t num_forward_outputs);

/*!
* \brief Issue a one-time warning that fusion is not possible for this platform or build.
Expand Down
Loading

0 comments on commit a0fd1fe

Please sign in to comment.