Eliminate common expressions (#15657)
* Eliminate common expressions from a graph

* Guard against optimizing out stateful ops and ops that require resources

* Fix lint

* Added THasDeterministicOutput to multiple ops

* Debug eliminate common expr

* Added test

* Expose get_optimized_symbol

* Fix

* Fix 2

* Add doc to the Python call

* Add env var MXNET_ELIMINATE_COMMON_EXPR, default true

* Add comments, improve readability of eliminate_common_expr_pass.cc

* Expand testing

* Lower priority of THasDeterministicOutput attr for equal Node test

* Change mx.gpu() to mx.cpu() in tests

* Skip CSE test on Windows (as env variable setting during test does not work there)

* Add missing import sys

* Add missing import logging
ptrendx committed Nov 1, 2019
1 parent 9f6070f commit 1aa1b5a
Showing 42 changed files with 421 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
@@ -339,6 +339,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Only applies to MXNet that has been compiled with CUDA and when ```MXNET_USE_FUSION``` option is enabled.
- If this variable is set, MXNet will print the code for fused operators that it generated.

* MXNET_ELIMINATE_COMMON_EXPR
  - Values: 0 (false) or 1 (true) ```(default=1)```
  - If set to 1 (the default), MXNet simplifies the computation graph by eliminating duplicated operations applied to the same inputs.
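For example, a minimal usage sketch of the new variable:

```python
import os

# "0" disables the elimination pass, "1" (the default) enables it;
# the value is read when a graph is initialized, so set it before
# binding an executor or running a hybridized block.
os.environ["MXNET_ELIMINATE_COMMON_EXPR"] = "0"
```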

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
11 changes: 11 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -218,6 +218,17 @@ using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
                                                  Context ctx,
                                                  const mxnet::ShapeVector& in_shape,
                                                  const std::vector<int>& in_type)>;

/*!
 * \brief Whether the operator always produces the same
 *        output given the same input.
 *        This enables certain optimizations
 *        like common expression elimination.
 *
 * \note Register under "THasDeterministicOutput"
 */
using THasDeterministicOutput = bool;

/*!
 * \brief Execution mode of this operator.
 */
16 changes: 15 additions & 1 deletion python/mxnet/executor.py
@@ -25,7 +25,7 @@
import copy
import numpy as np
from .base import _LIB
-from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str, mx_int
+from .base import mx_uint, NDArrayHandle, SymbolHandle, ExecutorHandle, py_str, mx_int
from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
@@ -511,3 +511,17 @@ def debug_str(self):
        check_call(_LIB.MXExecutorPrint(
            self.handle, ctypes.byref(debug_str)))
        return py_str(debug_str.value)

    def get_optimized_symbol(self):
        """Get an optimized version of the symbol from the executor.

        Returns
        -------
        symbol : Symbol
            Optimized symbol from the executor.
        """
        from .symbol import Symbol
        sym_handle = SymbolHandle()
        check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(sym_handle)))
        ret = Symbol(sym_handle)
        return ret
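A minimal sketch of how the new method can be used to inspect the effect of the pass (assumes the default MXNET_ELIMINATE_COMMON_EXPR=1):

```python
import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
# Each `a + b` below creates its own node, so the graph contains
# two textually identical subexpressions.
c = (a + b) * (a + b)
exe = c.simple_bind(mx.cpu(), a=(2, 2), b=(2, 2))
opt = exe.get_optimized_symbol()
# With the pass enabled, the duplicated elemwise_add should appear
# only once among the optimized symbol's internal nodes.
print(opt.get_internals().list_outputs())
```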
224 changes: 224 additions & 0 deletions src/executor/eliminate_common_expr_pass.cc
@@ -0,0 +1,224 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 * \file eliminate_common_expr_pass.cc
 * \brief Eliminate common expressions in the graph
 * \author Przemyslaw Tredak
 */

#include <mxnet/base.h>
#include <mxnet/op_attr_types.h>

#include <vector>
#include <map>
#include <unordered_set>
#include <utility>
#include <sstream>
namespace mxnet {
namespace exec {

namespace {

using nnvm::Node;
using nnvm::NodePtr;
using nnvm::Graph;
using nnvm::IndexedGraph;

// NodeInput holds the sufficient subset of NodeEntry fields for Node-input equality tests
using NodeInput = std::pair<const Node*, uint32_t>;

/*!
 * \brief Convert a Node's input vector of `NodeEntry` to a vector of the simpler `NodeInput`
 */
std::vector<NodeInput> ConvertInputs(const std::vector<nnvm::NodeEntry>& inputs) {
  std::vector<NodeInput> ret;
  for (const auto& entry : inputs) {
    ret.emplace_back(entry.node.get(), entry.index);
  }
  return ret;
}

/*!
 * \brief Determine if two Nodes have equal function such that one Node can be eliminated.
 */
bool NodeEqual(const Node* n, const Node* m) {
  if (n->is_variable() || m->is_variable()) return false;
  if (n->op() != m->op()) return false;
  // Nodes with different attributes are considered not identical,
  // though this may reject Node pairs that are in fact functionally the same.
  if (n->attrs.dict != m->attrs.dict) return false;

  // Ops that mutate inputs cannot be optimized out
  static auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
  if (fmutate_inputs.get(n->op(), nullptr) != nullptr) return false;

  // Stateful ops cannot be equal to each other
  static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
  if (fstateful.get(n->op(), nullptr) != nullptr)
    return false;

  // Check whether the op has explicitly set THasDeterministicOutput, which overrides the
  // subsequent determination of Node equality based on resource use.
  static auto& deterministic_output =
      Op::GetAttr<THasDeterministicOutput>("THasDeterministicOutput");
  if (deterministic_output.contains(n->op()))
    return deterministic_output[n->op()];

  // Ops that request a resource could ask for a random resource,
  // so they need to be explicitly marked deterministic to be eligible
  static auto& resource_request = Op::GetAttr<FResourceRequest>("FResourceRequest");
  static auto& resource_request_ex = Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
  if (resource_request.get(n->op(), nullptr) != nullptr) return false;
  if (resource_request_ex.get(n->op(), nullptr) != nullptr) return false;

  return true;
}

// Graph traversal to create a list of pairs of identical-function nodes that can be combined.
std::vector<std::pair<NodePtr, NodePtr> > GetCommonNodes(const Graph& g) {
  std::vector<std::pair<NodePtr, NodePtr> > ret;
  // A map between a vector of inputs and those nodes that have those inputs
  std::map<std::vector<NodeInput>, std::vector<const NodePtr*> > grouped_nodes;
  // Traverse the graph and group the nodes by their vector of inputs
  nnvm::DFSVisit(g.outputs, [&grouped_nodes](const NodePtr& n) {
    if (n->inputs.size() != 0) {
      grouped_nodes[ConvertInputs(n->inputs)].push_back(&n);
    }
  });
  // Now check for identical node ops within the node groups (having identical inputs)
  for (const auto& pair : grouped_nodes) {
    auto& node_group = pair.second;  // Group of nodes that share the same vector of inputs
    if (node_group.size() > 1) {
      std::unordered_set<size_t> visited;
      for (size_t i = 0; i < node_group.size(); ++i) {
        if (visited.count(i)) continue;
        for (size_t j = i + 1; j < node_group.size(); ++j) {
          // If the two Nodes have equal function, then one Node (called the 'replaced') can
          // be eliminated in favor of the other Node (the 'src').
          if (NodeEqual(node_group[i]->get(), node_group[j]->get())) {
            visited.insert(j);
            NodePtr src = *node_group[i];
            NodePtr replaced = *node_group[j];
            ret.emplace_back(src, replaced);
          }
        }
      }
    }
  }
  return ret;
}

/*!
 * \brief Do a single pass of Node elimination given pairs of identical Nodes.
 */
void EliminateCommonNodes(Graph* g,
                          const std::vector<std::pair<NodePtr, NodePtr> >& common_nodes) {
  for (const auto& p : common_nodes) {
    std::vector<NodePtr> nodes_to_change;
    const NodePtr& src = p.first;
    const NodePtr& replaced = p.second;
    // Create a `nodes_to_change` list containing the Nodes that refer to the `replaced` Node
    // that is targeted for elimination.
    DFSVisit(g->outputs, [replaced, &nodes_to_change](const NodePtr& n) {
      for (const auto& dep : n->control_deps) {
        if (dep == replaced) {
          nodes_to_change.push_back(n);
          return;
        }
      }
      for (const auto& inp : n->inputs) {
        if (inp.node == replaced) {
          nodes_to_change.push_back(n);
          return;
        }
      }
    });

    // Change references to the `replaced` Node within the `nodes_to_change` list to be
    // references to the equivalent `src` Node.
    for (auto& n : nodes_to_change) {
      for (auto& dep : n->control_deps) {
        if (dep == replaced) {
          dep = src;
        }
      }
      for (auto& inp : n->inputs) {
        if (inp.node == replaced) {
          inp.node = src;
        }
      }
    }

    // Add `replaced` Node control dependencies to those of the `src` Node.
    for (const auto& n : replaced->control_deps) {
      src->control_deps.push_back(n);
    }

    // Change graph outputs driven by the `replaced` Node to now point to the `src` Node.
    for (auto& out : g->outputs) {
      if (out.node == replaced) {
        out.node = src;
      }
    }
  }
  // Check for duplicates in outputs and
  // insert Copy nodes as appropriate
  const Op* copy_op = Op::Get("_copy");
  nnvm::NodeEntryMap<size_t> unique_outputs;
  for (size_t i = 0; i < g->outputs.size(); ++i) {
    auto kv = unique_outputs.find(g->outputs[i]);
    if (kv == unique_outputs.end()) {
      unique_outputs.emplace(g->outputs[i], 0);
    } else {
      NodePtr copy_node = Node::Create();
      std::ostringstream os;
      os << kv->first.node->attrs.name << "_" << kv->second << "_copy";
      kv->second++;
      copy_node->attrs.op = copy_op;
      copy_node->attrs.name = os.str();
      copy_node->inputs.emplace_back(kv->first);
      g->outputs[i] = nnvm::NodeEntry{copy_node, 0, 0};
    }
  }
}

}  // namespace

/*!
 * \brief Simplify a graph by iteratively eliminating Nodes with identical inputs and function.
 */
nnvm::Graph EliminateCommonExpr(nnvm::Graph&& g) {
  using nnvm::NodePtr;
  bool keep_running = true;
  while (keep_running) {
    const auto& common_nodes = GetCommonNodes(g);
    if (common_nodes.empty()) {
      keep_running = false;
    } else {
      EliminateCommonNodes(&g, common_nodes);
    }
  }
  return g;
}

}  // namespace exec
}  // namespace mxnet
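For readers who want the gist of the pass without the nnvm plumbing, here is a hypothetical, heavily simplified Python model of the same fixed-point algorithm. The dict-based graph encoding and all names are invented for illustration; it omits the mutating/stateful/resource checks and the output copy-node handling:

```python
from collections import defaultdict

def eliminate_common_subexpr(nodes):
    """nodes: dict of node-id -> (op, attrs, inputs); inputs are producer ids."""
    changed = True
    while changed:  # iterate to a fixed point, like EliminateCommonExpr
        changed = False
        # Group nodes by their input signature (cf. GetCommonNodes).
        groups = defaultdict(list)
        for nid, (op, attrs, inputs) in nodes.items():
            if inputs:  # variables (no inputs) are never merged
                groups[inputs].append(nid)
        replaced = {}
        for group in groups.values():
            for i, keep in enumerate(group):
                if keep in replaced:
                    continue
                for cand in group[i + 1:]:
                    # 'equal function' test: same op and same attributes.
                    if cand not in replaced and nodes[keep][:2] == nodes[cand][:2]:
                        replaced[cand] = keep
        if replaced:
            changed = True
            # Rewrite the inputs of the survivors and drop merged nodes.
            nodes = {nid: (op, attrs, tuple(replaced.get(i, i) for i in inputs))
                     for nid, (op, attrs, inputs) in nodes.items()
                     if nid not in replaced}
    return nodes

# The duplicated 'add' over identical inputs collapses to a single node:
g = {0: ('var_a', (), ()), 1: ('var_b', (), ()),
     2: ('add', (), (0, 1)), 3: ('add', (), (0, 1)),
     4: ('mul', (), (2, 3))}
print(eliminate_common_subexpr(g))
# node 3 is gone and 'mul' now reads inputs (2, 2)
```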
9 changes: 9 additions & 0 deletions src/executor/exec_pass.h
@@ -194,6 +194,15 @@ void AttachOpResources(const Graph& g,
*/
Graph DetectInplaceAddTo(Graph g);

/*!
 * \brief Eliminate common expressions in the graph.
 *
 * \param g input forward graph
 *
 * \return graph with common expressions eliminated
 */
Graph EliminateCommonExpr(Graph&& g);

/*!
* \brief Fuse pointwise operations in the forward pass.
*
3 changes: 3 additions & 0 deletions src/executor/graph_executor.cc
@@ -331,6 +331,9 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,

  nnvm::Graph g;
  g.outputs = symbol.outputs;
  bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true);
  if (do_elim_common_expr)
    g = exec::EliminateCommonExpr(std::move(g));
  need_grad_ = false;
  for (OpReqType req : grad_req_types) {
    if (req != kNullOp) need_grad_ = true;
6 changes: 5 additions & 1 deletion src/imperative/cached_op.cc
@@ -93,6 +93,10 @@ void CreateFullGraph(const nnvm::Symbol& sym,
}
}

  bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true);
  if (do_elim_common_expr)
    *fwd_graph = exec::EliminateCommonExpr(std::move(*fwd_graph));

  // construct backward graph
  {
    ograd_entries->reserve(fwd_graph->outputs.size());
@@ -278,7 +282,7 @@ CachedOp::CachedOp(

  auto grad_graph = nnvm::Graph();
  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output;
-  CreateFullGraph(sym, &fwd_graph_, &grad_graph, &full_graph_,
+  CreateFullGraph(sym.Copy(), &fwd_graph_, &grad_graph, &full_graph_,
                  &ograd_entries_, &fwd_input_to_grad_output);

  {
1 change: 1 addition & 0 deletions src/operator/contrib/boolean_mask.cu
@@ -157,6 +157,7 @@ NNVM_REGISTER_OP(_contrib_boolean_mask)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>);

NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
2 changes: 2 additions & 0 deletions src/operator/contrib/bounding_box.cc
@@ -102,6 +102,7 @@ Examples::
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", BoxNMSForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_contrib_box_nms"})
.add_argument("data", "NDArray-or-Symbol", "The input")
@@ -186,6 +187,7 @@ NNVM_REGISTER_OP(_contrib_bipartite_matching)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", MatchingShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FCompute>("FCompute<cpu>", BipartiteMatchingForward<cpu>)
1 change: 1 addition & 0 deletions src/operator/contrib/hawkes_ll.cc
@@ -104,6 +104,7 @@ Example::
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::Type::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.add_argument(
"lda", "NDArray-or-Symbol",
"Shape (N, K) The intensity for each of the K processes, for each sample"
1 change: 1 addition & 0 deletions src/operator/contrib/index_array.cc
@@ -163,6 +163,7 @@ Examples::
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_arguments(IndexArrayParam::__FIELDS__());

1 change: 1 addition & 0 deletions src/operator/loss_binary_op.cc
@@ -65,6 +65,7 @@ Example::
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCrossEntropyForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_softmax_cross_entropy"})
.set_attr<nnvm::FListInputNames>("FListInputNames",
2 changes: 2 additions & 0 deletions src/operator/nn/concat.cc
@@ -385,6 +385,7 @@ Example::
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<bool>("TIsMKLDNN", true)
#endif // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
@@ -422,6 +423,7 @@ NNVM_REGISTER_OP(_rnn_param_concat)
})
#endif // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());
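As a hypothetical illustration of why the marking above matters: Concat requests temporary workspace, so without THasDeterministicOutput the pass would conservatively refuse to merge the two identical concats below. A sketch, assuming the default MXNET_ELIMINATE_COMMON_EXPR=1:

```python
import mxnet as mx

a = mx.sym.Variable('a')
# Two identical concats over the same input; both request temp space,
# so they are merged only because Concat is marked deterministic.
c1 = mx.sym.concat(a, a, dim=0)
c2 = mx.sym.concat(a, a, dim=0)
out = c1 + c2
exe = out.simple_bind(mx.cpu(), a=(2, 2))
# Only one Concat node should remain in the optimized graph.
print(exe.get_optimized_symbol().get_internals().list_outputs())
```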
1 change: 1 addition & 0 deletions src/operator/nn/convolution.cc
@@ -503,6 +503,7 @@ There are other options to tune the performance.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.add_argument("data", "NDArray-or-Symbol", "Input data to the ConvolutionOp.")
.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.")