diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 04678d9962b2..e4fe58a116c5 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -339,6 +339,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - Only applies to MXNet that has been compiled with CUDA and when ```MXNET_USE_FUSION``` option is enabled. - If this variable is set, MXNet will print the code for fused operators that it generated. +* MXNET_ELIMINATE_COMMON_EXPR + - Values: 0(false) or 1(true) ```(default=1)``` + - If this variable is set, MXNet will simplify the computation graph, eliminating duplicated operations on the same inputs. + Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 75d843c98bd2..7c0ea77dc986 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -218,6 +218,17 @@ using FCreateOpState = std::function& in_type)>; + +/*! + * \brief Whether the operator always produces the same + * output given the same input. + * This enables certain optimizations + * like common expression elimination. + * + * \note Register under "THasDeterministicOutput" + */ +using THasDeterministicOutput = bool; + /*! * \brief Execution mode of this operator. */ diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index edc10dff18c2..3b79f0c8d1b4 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -25,7 +25,7 @@ import copy import numpy as np from .base import _LIB -from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str, mx_int +from .base import mx_uint, NDArrayHandle, SymbolHandle, ExecutorHandle, py_str, mx_int from .base import check_call, c_handle_array, c_array_buf, c_str_array from .ndarray import NDArray from .ndarray import _ndarray_cls @@ -511,3 +511,17 @@ def debug_str(self): check_call(_LIB.MXExecutorPrint( self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) + + def get_optimized_symbol(self): + """Get an optimized version of the symbol from the executor. + + Returns + ------- + symbol : Symbol + Optimized symbol from the executor. + """ + from .symbol import Symbol + sym_handle = SymbolHandle() + check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(sym_handle))) + ret = Symbol(sym_handle) + return ret diff --git a/src/executor/eliminate_common_expr_pass.cc b/src/executor/eliminate_common_expr_pass.cc new file mode 100644 index 000000000000..5c77ec25b325 --- /dev/null +++ b/src/executor/eliminate_common_expr_pass.cc @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file eliminate_common_expr.cc + * \brief Eliminate common expressions in the graph + * \author Przemyslaw Tredak + */ + +#include +#include + +#include +#include +#include +#include + +namespace mxnet { +namespace exec { + +namespace { + +using nnvm::Node; +using nnvm::NodePtr; +using nnvm::Graph; +using nnvm::IndexedGraph; + +// NodeInput holds the sufficient subset of NodeEntry fields for Node-input equality tests +using NodeInput = std::pair; + +/*! + * \brief Convert a Node's input vector of `NodeEntry` to a vector of the simpler `NodeInput` + */ +std::vector ConvertInputs(const std::vector& inputs) { + std::vector ret; + for (const auto& entry : inputs) { + ret.emplace_back(entry.node.get(), entry.index); + } + return ret; +} + +/*! + * \brief Determine if two Nodes have equal function such that one Node can be eliminated. + */ +bool NodeEqual(const Node* n, const Node* m) { + if (n->is_variable() || m->is_variable()) return false; + if (n->op() != m->op()) return false; + // Nodes with different attributes are considered not identical, + // though this may reject Node pairs that are in fact functionally the same. + if (n->attrs.dict != m->attrs.dict) return false; + + // Ops that mutate inputs cannot be optimized out + static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); + if (fmutate_inputs.get(n->op(), nullptr) != nullptr) return false; + + // Stateful ops cannot be be equal to each other + static auto& fstateful = Op::GetAttr("FCreateOpState"); + if (fstateful.get(n->op(), nullptr) != nullptr) + return false; + + // Check to see if the user has explicitly set THasDeterministicOutput to override the + // subsequent determination of Node equality based on resource use. + static auto& deterministic_output = + Op::GetAttr("THasDeterministicOutput"); + if (deterministic_output.contains(n->op())) + return deterministic_output[n->op()]; + + // Ops that require resource could ask for + // random resource, so need to be explicitly marked + // to be eligible + static auto& resource_request = Op::GetAttr("FResourceRequest"); + static auto& resource_request_ex = Op::GetAttr("FResourceRequestEx"); + if (resource_request.get(n->op(), nullptr) != nullptr) return false; + if (resource_request_ex.get(n->op(), nullptr) != nullptr) return false; + + return true; +} + +// Graph traversal to create a list of pairs of identical-function nodes that can be combined. +std::vector > GetCommonNodes(const Graph& g) { + std::vector > ret; + // A map between a vector of inputs and those nodes that have those inputs + std::map, std::vector > grouped_nodes; + // Traverse the graph and group the nodes by their vector of inputs + nnvm::DFSVisit(g.outputs, [&grouped_nodes](const NodePtr& n) { + if (n->inputs.size() != 0) { + grouped_nodes[ConvertInputs(n->inputs)].push_back(&n); + } + }); + // Now check for identical node ops within the node groups (having identical inputs) + for (const auto& pair : grouped_nodes) { + auto &node_group = pair.second; // Group of nodes that share the same vector of inputs + if (node_group.size() > 1) { + std::unordered_set visited; + for (size_t i = 0; i < node_group.size(); ++i) { + if (visited.count(i)) continue; + for (size_t j = i + 1; j < node_group.size(); ++j) { + // If the two Nodes have equal function, then one Node (called the 'replaced') can + // be eliminated in favor of the other Node (the 'src'). 
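+          // Keep the earlier Node of the pair as the surviving 'src'; record the later
+          // duplicate as 'replaced' and mark its index visited so it is paired for
+          // elimination only once.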
+ if (NodeEqual(node_group[i]->get(), node_group[j]->get())) { + visited.insert(j); + NodePtr src = *node_group[i]; + NodePtr replaced = *node_group[j]; + ret.emplace_back(src, replaced); + } + } + } + } + } + return ret; +} + +/*! + * \brief Do a single pass of Node elimination given pairs of identical Nodes. + */ +void EliminateCommonNodes(Graph* g, + const std::vector >& common_nodes) { + for (const auto &p : common_nodes) { + std::vector nodes_to_change; + const NodePtr &src = p.first; + const NodePtr &replaced = p.second; + // Create a `nodes_to_change` list containing the Nodes that refer to the `replaced` Node + // that is targeted for elimination. + DFSVisit(g->outputs, [replaced, &nodes_to_change](const NodePtr &n) { + for (const auto &dep : n->control_deps) { + if (dep == replaced) { + nodes_to_change.push_back(n); + return; + } + } + for (const auto &inp : n->inputs) { + if (inp.node == replaced) { + nodes_to_change.push_back(n); + return; + } + } + }); + + // Change references to the `replaced` Node within the `nodes_to_change` list to be + // references to the equivalent `src` Node. + for (auto &n : nodes_to_change) { + for (auto &dep : n->control_deps) { + if (dep == replaced) { + dep = src; + } + } + for (auto &inp : n->inputs) { + if (inp.node == replaced) { + inp.node = src; + } + } + } + + // Add `replaced` Node control dependencies to those of the `src` Node. + for (const auto &n : replaced->control_deps) { + src->control_deps.push_back(n); + } + + // Change graph outputs driven by the `replaced` Node to now point to the `src` Node. + for (auto& out : g->outputs) { + if (out.node == replaced) { + out.node = src; + } + } + } + // Check for duplicates in outputs and + // insert Copy nodes as appropriate + const Op* copy_op = Op::Get("_copy"); + nnvm::NodeEntryMap unique_outputs; + for (size_t i = 0; i < g->outputs.size(); ++i) { + auto kv = unique_outputs.find(g->outputs[i]); + if (kv == unique_outputs.end()) { + unique_outputs.emplace(g->outputs[i], 0); + } else { + NodePtr copy_node = Node::Create(); + std::ostringstream os; + os << kv->first.node->attrs.name << "_" << kv->second << "_copy"; + kv->second++; + copy_node->attrs.op = copy_op; + copy_node->attrs.name = os.str(); + copy_node->inputs.emplace_back(kv->first); + g->outputs[i] = nnvm::NodeEntry{copy_node, 0, 0}; + } + } +} + +} // namespace + +/*! + * \brief Simplify a graph by iteratively eliminating Nodes with identical inputs and function. + */ +nnvm::Graph EliminateCommonExpr(nnvm::Graph&& g) { + using nnvm::NodePtr; + bool keep_running = true; + while (keep_running) { + const auto& common_nodes = GetCommonNodes(g); + if (common_nodes.empty()) { + keep_running = false; + } else { + EliminateCommonNodes(&g, common_nodes); + } + } + return g; +} + +} // namespace exec +} // namespace mxnet diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 25a326171510..a5f125affcb0 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -194,6 +194,15 @@ void AttachOpResources(const Graph& g, */ Graph DetectInplaceAddTo(Graph g); +/*! + * \brief Eliminate common expressions in the graph. + * + * \param g input forward graph + * + * \return graph with common expressions eliminated + */ +Graph EliminateCommonExpr(Graph && g); + /*! * \brief Fuse pointwise operations in the forward pass. 
* diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 4f1553bc19d5..7fa1de373d07 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -331,6 +331,9 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, nnvm::Graph g; g.outputs = symbol.outputs; + bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true); + if (do_elim_common_expr) + g = exec::EliminateCommonExpr(std::move(g)); need_grad_ = false; for (OpReqType req : grad_req_types) { if (req != kNullOp) need_grad_ = true; diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index dd392d3e0401..269729c18f58 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -93,6 +93,10 @@ void CreateFullGraph(const nnvm::Symbol& sym, } } + bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true); + if (do_elim_common_expr) + *fwd_graph = exec::EliminateCommonExpr(std::move(*fwd_graph)); + // construct backward graph { ograd_entries->reserve(fwd_graph->outputs.size()); @@ -278,7 +282,7 @@ CachedOp::CachedOp( auto grad_graph = nnvm::Graph(); std::unordered_map fwd_input_to_grad_output; - CreateFullGraph(sym, &fwd_graph_, &grad_graph, &full_graph_, + CreateFullGraph(sym.Copy(), &fwd_graph_, &grad_graph, &full_graph_, &ograd_entries_, &fwd_input_to_grad_output); { diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu index f6c1df0c62a8..a5ef4a70d99b 100644 --- a/src/operator/contrib/boolean_mask.cu +++ b/src/operator/contrib/boolean_mask.cu @@ -157,6 +157,7 @@ NNVM_REGISTER_OP(_contrib_boolean_mask) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FComputeEx", BooleanMaskForward); NNVM_REGISTER_OP(_backward_contrib_boolean_mask) diff --git a/src/operator/contrib/bounding_box.cc b/src/operator/contrib/bounding_box.cc index 62b7c2e0bf4b..3ab11bb2d6f9 100644 --- a/src/operator/contrib/bounding_box.cc +++ b/src/operator/contrib/bounding_box.cc @@ -102,6 +102,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", BoxNMSForward) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_contrib_box_nms"}) .add_argument("data", "NDArray-or-Symbol", "The input") @@ -186,6 +187,7 @@ NNVM_REGISTER_OP(_contrib_bipartite_matching) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", MatchingShape) .set_attr("FInferType", ElemwiseType<1, 2>) .set_attr("FCompute", BipartiteMatchingForward) diff --git a/src/operator/contrib/hawkes_ll.cc b/src/operator/contrib/hawkes_ll.cc index 758ab2012580..1e2fff5c9871 100644 --- a/src/operator/contrib/hawkes_ll.cc +++ b/src/operator/contrib/hawkes_ll.cc @@ -104,6 +104,7 @@ Example:: .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::Type::kTempSpace}; }) + .set_attr("THasDeterministicOutput", true) .add_argument( "lda", "NDArray-or-Symbol", "Shape (N, K) The intensity for each of the K processes, for each sample" diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index a70dee106314..ef4f030863f2 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -163,6 +163,7 @@ Examples:: .set_attr("FResourceRequest", [](const NodeAttrs& n) { return 
std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input data") .add_arguments(IndexArrayParam::__FIELDS__()); diff --git a/src/operator/loss_binary_op.cc b/src/operator/loss_binary_op.cc index 696c8589a0dc..5bf49669db89 100644 --- a/src/operator/loss_binary_op.cc +++ b/src/operator/loss_binary_op.cc @@ -65,6 +65,7 @@ Example:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", SoftmaxCrossEntropyForward) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_softmax_cross_entropy"}) .set_attr("FListInputNames", diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index fa62b0044a53..4d90810915a2 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -385,6 +385,7 @@ Example:: .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("TIsMKLDNN", true) #endif // MXNET_USE_MKLDNN == 1 CONCAT_FORWARD_ATTRS @@ -422,6 +423,7 @@ NNVM_REGISTER_OP(_rnn_param_concat) }) #endif // MXNET_USE_MKLDNN == 1 CONCAT_FORWARD_ATTRS +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", RNNParamConcatShape) .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") .add_arguments(ConcatParam::__FIELDS__()); diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index e31073034594..6d9f84ffc510 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -503,6 +503,7 @@ There are other options to tune the performance. .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input data to the ConvolutionOp.") .add_argument("weight", "NDArray-or-Symbol", "Weight matrix.") .add_argument("bias", "NDArray-or-Symbol", "Bias parameter.") diff --git a/src/operator/nn/ctc_loss.cc b/src/operator/nn/ctc_loss.cc index f718b42bfaa4..aba76fb0c452 100644 --- a/src/operator/nn/ctc_loss.cc +++ b/src/operator/nn/ctc_loss.cc @@ -115,6 +115,7 @@ information on the definition and the algorithm. .set_attr("FInferStorageType", CTCLossOpStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", CTCLossOpForward) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_ctc_loss"}) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index b61f9ff37002..bbcec53e933d 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -430,6 +430,7 @@ NNVM_REGISTER_OP(Deconvolution) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", DeconvolutionCompute) .set_attr("FGradient", DeconvolutionGrad{"_backward_Deconvolution"}) #if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 1f6d9e313202..5d722581257f 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -314,6 +314,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. 
return std::vector{ResourceRequest::kTempSpace}; }) #endif +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", FullyConnectedShape) .set_attr("FInferType", FullyConnectedType) .set_attr("FCompute", FullyConnectedCompute) diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc index b4698abeff83..06430c281920 100644 --- a/src/operator/nn/group_norm.cc +++ b/src/operator/nn/group_norm.cc @@ -111,6 +111,7 @@ Both ``gamma`` and ``beta`` are learnable parameters. .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input data") .add_argument("gamma", "NDArray-or-Symbol", "gamma array") .add_argument("beta", "NDArray-or-Symbol", "beta array") diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 0b53d5091194..1b2a43b2501c 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -194,6 +194,7 @@ axis to be the last item in the input shape. .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input data to layer normalization") .add_argument("gamma", "NDArray-or-Symbol", "gamma array") .add_argument("beta", "NDArray-or-Symbol", "beta array") diff --git a/src/operator/nn/moments.cc b/src/operator/nn/moments.cc index 37b8cdf18750..180615e53d61 100644 --- a/src/operator/nn/moments.cc +++ b/src/operator/nn/moments.cc @@ -66,6 +66,7 @@ If x is 1-D and axes = [0] this is just the mean and variance of a vector. [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseInOut{"_backward_moments"}) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { diff --git a/src/operator/nn/softmax_activation.cc b/src/operator/nn/softmax_activation.cc index 8a28243dfced..9e5a3ab8f6a2 100644 --- a/src/operator/nn/softmax_activation.cc +++ b/src/operator/nn/softmax_activation.cc @@ -75,6 +75,7 @@ NNVM_REGISTER_OP(_backward_SoftmaxActivation) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxActivationGradCompute); diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index 971ff6ad560b..d36b2598ce82 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -195,6 +195,7 @@ Example:: return std::vector{ResourceRequest::kTempSpace}; } }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", UpSamplingCompute) .set_attr("FGradient", UpSamplingGrad{"_backward_UpSampling"}) .set_attr("key_var_num_args", "num_args") diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index fdda792a9ed8..435fe1df1134 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -138,6 +138,7 @@ NNVM_REGISTER_OP(_np_sum) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_sum"}); NNVM_REGISTER_OP(_backward_np_sum) @@ -176,6 +177,7 @@ NNVM_REGISTER_OP(_np_max) [](const NodeAttrs& attrs) { return 
std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{"_backward_np_max"}); NNVM_REGISTER_OP(_backward_np_max) @@ -203,6 +205,7 @@ return std::vector{"a"}; [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{"_backward_np_min"}); NNVM_REGISTER_OP(_backward_np_min) @@ -229,6 +232,7 @@ NNVM_REGISTER_OP(_np_prod) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{"_backward_np_prod"}); NNVM_REGISTER_OP(_backward_np_prod) @@ -282,6 +286,7 @@ NNVM_REGISTER_OP(_npi_mean) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_mean"}); NNVM_REGISTER_OP(_backward_np_mean) @@ -350,6 +355,7 @@ NNVM_REGISTER_OP(_npi_std) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", MakeZeroGradNodes); NNVM_REGISTER_OP(_npi_var) @@ -377,6 +383,7 @@ NNVM_REGISTER_OP(_npi_var) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", MakeZeroGradNodes); bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc index 6afc896a7720..feb032ae07ea 100644 --- a/src/operator/numpy/np_dot.cc +++ b/src/operator/numpy/np_dot.cc @@ -131,6 +131,7 @@ NNVM_REGISTER_OP(_np_dot) [](const NodeAttrs& attrs) { return std::vector(1, ResourceRequest::kTempSpace); }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", NumpyDotForward) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_np_dot"}) .add_argument("a", "NDArray-or-Symbol", "First input") diff --git a/src/operator/numpy/np_tensordot_op.cc b/src/operator/numpy/np_tensordot_op.cc index aca45c1652ee..96de0decf73a 100644 --- a/src/operator/numpy/np_tensordot_op.cc +++ b/src/operator/numpy/np_tensordot_op.cc @@ -113,6 +113,7 @@ NNVM_REGISTER_OP(_npi_tensordot) [](const NodeAttrs& attrs) { return std::vector(1, ResourceRequest::kTempSpace); }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", TensordotOpForward) .set_attr("FGradient", mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot"}) .add_argument("a", "NDArray-or-Symbol", "First input") @@ -213,6 +214,7 @@ NNVM_REGISTER_OP(_npi_tensordot_int_axes) [](const NodeAttrs& attrs) { return std::vector(1, ResourceRequest::kTempSpace); }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", TensordotIntAxesOpForward) .set_attr("FGradient", mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot_int_axes"}) diff --git a/src/operator/tensor/broadcast_reduce_minmax_value.cc b/src/operator/tensor/broadcast_reduce_minmax_value.cc index f8bc33ba375d..e77d42b042ae 100644 --- a/src/operator/tensor/broadcast_reduce_minmax_value.cc +++ b/src/operator/tensor/broadcast_reduce_minmax_value.cc @@ -35,6 +35,7 @@ MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(max) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{"_backward_max"}); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_max) @@ -49,6 +50,7 @@ MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(min) 
[](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{"_backward_min"}); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_min) diff --git a/src/operator/tensor/broadcast_reduce_norm_value.cc b/src/operator/tensor/broadcast_reduce_norm_value.cc index 63a05b4980fc..4cd92d44997e 100644 --- a/src/operator/tensor/broadcast_reduce_norm_value.cc +++ b/src/operator/tensor/broadcast_reduce_norm_value.cc @@ -98,6 +98,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", LpNormCompute) .set_attr("FComputeEx", L2NormComputeEx) .add_argument("data", "NDArray-or-Symbol", "The input") diff --git a/src/operator/tensor/broadcast_reduce_prod_value.cc b/src/operator/tensor/broadcast_reduce_prod_value.cc index 4778865bf11d..a38f37a3e55c 100644 --- a/src/operator/tensor/broadcast_reduce_prod_value.cc +++ b/src/operator/tensor/broadcast_reduce_prod_value.cc @@ -34,6 +34,7 @@ MXNET_OPERATOR_REGISTER_REDUCE(prod) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{ "_backward_prod" }); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_prod) @@ -49,6 +50,7 @@ MXNET_OPERATOR_REGISTER_REDUCE(nanprod) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{ "_backward_nanprod" }); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_nanprod) diff --git a/src/operator/tensor/broadcast_reduce_sum_value.cc b/src/operator/tensor/broadcast_reduce_sum_value.cc index c5c9f5cb48e4..53e37e437f96 100644 --- a/src/operator/tensor/broadcast_reduce_sum_value.cc +++ b/src/operator/tensor/broadcast_reduce_sum_value.cc @@ -72,6 +72,7 @@ Example:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_sum"}); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_sum) @@ -88,6 +89,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(mean) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_mean"}); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_mean) @@ -103,6 +105,7 @@ MXNET_OPERATOR_REGISTER_REDUCE(nansum) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ReduceGrad{ "_backward_nansum" }); MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_nansum) diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc index 5d93979a5bb7..ce5025696619 100644 --- a/src/operator/tensor/cast_storage.cc +++ b/src/operator/tensor/cast_storage.cc @@ -79,6 +79,7 @@ Example:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FComputeEx", CastStorageComputeEx) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}) diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 556260ed9600..32d1c81ed40b 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -89,6 +89,7 @@ above patterns, ``dot`` will fallback and generate output with 
default storage. [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", DotForward_) .set_attr("FComputeEx", DotForwardEx) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) @@ -137,6 +138,7 @@ which is computed by:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", BatchDotForward_) .set_attr("FGradient", [](const nnvm::NodePtr& n, diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index c5e30c68de7e..50772bc075d4 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -86,6 +86,7 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .set_attr("TIsMKLDNN", true) #endif .set_attr("FComputeEx", ElemwiseAddEx) +.set_attr("THasDeterministicOutput", true) .set_attr("FResourceRequest", /* For Sparse CSR */ [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace};}) @@ -232,6 +233,7 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_alias("_mul").add_alias("_Mul") .set_attr("FGradient", ElemwiseGradUseIn{"_backward_mul"}); diff --git a/src/operator/tensor/elemwise_scatter_op.cc b/src/operator/tensor/elemwise_scatter_op.cc index dd6da0ce41aa..41f22b057a53 100644 --- a/src/operator/tensor/elemwise_scatter_op.cc +++ b/src/operator/tensor/elemwise_scatter_op.cc @@ -93,6 +93,7 @@ with default storage [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_div"}); /*! \brief _scatter_plus_scalar */ diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 75553ef2c2a5..d1b86d161e89 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -178,6 +178,7 @@ The storage type of ``add_n`` output depends on storage types of inputs [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) #endif diff --git a/src/operator/tensor/histogram.cc b/src/operator/tensor/histogram.cc index b7896e9e0016..78234873772d 100644 --- a/src/operator/tensor/histogram.cc +++ b/src/operator/tensor/histogram.cc @@ -152,6 +152,7 @@ Example:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", HistogramOpShape) .set_attr("FInferType", HistogramOpType) .set_attr("FCompute", HistogramOpForward) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 470abee71a59..4bba683f0f28 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -551,6 +551,7 @@ The storage type of weight can be either row_sparse or default. 
[](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", EmbeddingOpForward) .set_attr("FComputeEx", SparseEmbeddingOpForwardEx) .set_attr("FGradient", @@ -624,6 +625,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", EmbeddingOpShape) .set_attr("FInferType", EmbeddingOpType) .set_attr("FInferStorageType", SparseEmbeddingOpForwardStorageType) @@ -728,6 +730,7 @@ The storage type of ``take`` output depends upon the input storage type: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", TakeOpForward) .set_attr("FComputeEx", TakeOpForwardEx) .set_attr("FGradient", diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index ce7d1d5de692..3d0e43251e03 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -806,6 +806,7 @@ Examples:: { return std::vector>{{0, 0}}; }) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", LaOpForward) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_gelqf"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be factorized"); @@ -875,6 +876,7 @@ Examples:: { return std::vector>{{0, 0}}; }) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", LaOpForwSyevd) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_syevd"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be factorized"); @@ -925,6 +927,7 @@ Examples:: { return std::vector>{{0, 0}}; }) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", LaOpForward) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_inverse"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); @@ -978,6 +981,7 @@ Examples:: .set_attr("FInferType", DetType<1>) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FCompute", LaOpDetForward) .set_attr("FGradient", ReduceDetGrad<1>{"_backward_linalg_det"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 0f63061d7c09..eee5ea67f6e1 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -196,6 +196,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from [](const NodeAttrs& attrs){ return std::vector{true}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.") .add_arguments(ReshapeParam::__FIELDS__()); @@ -269,6 +270,7 @@ Example:: [](const NodeAttrs& attrs){ return std::vector{true}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "Input array."); #if MXNET_USE_MKLDNN == 1 @@ -484,6 +486,7 @@ Example:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) 
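+// Note: the kTempSpace resource request above would by itself make the common-expression
+// pass skip this op, hence the explicit THasDeterministicOutput marking.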
.set_attr("FInferStorageType", SliceForwardInferStorageType) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", SliceOpForward) @@ -836,6 +839,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector {ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FCompute", ReverseOpForward) @@ -977,6 +981,7 @@ Example:: [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"space_to_depth"}) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(DepthToSpaceParam::__FIELDS__()); @@ -1023,6 +1028,7 @@ Example:: [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"depth_to_space"}) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(DepthToSpaceParam::__FIELDS__()); @@ -1091,6 +1097,7 @@ Example:: [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FGradient", ElemwiseGradUseNone{"_split_v2_backward"}) .add_argument("data", "NDArray-or-Symbol", "The input") .add_arguments(SplitParam::__FIELDS__()); diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc index e36416114e31..6c375ce8e3c2 100644 --- a/src/operator/tensor/ordering_op.cc +++ b/src/operator/tensor/ordering_op.cc @@ -91,6 +91,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "The input array") .add_arguments(TopKParam::__FIELDS__()); @@ -154,6 +155,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "The input array") .add_arguments(SortParam::__FIELDS__()); @@ -190,6 +192,7 @@ Examples:: [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .add_argument("data", "NDArray-or-Symbol", "The input array") .add_arguments(ArgSortParam::__FIELDS__()); } // namespace op diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc index 94d79c7d07a6..e04628efab92 100644 --- a/src/operator/tensor/ravel.cc +++ b/src/operator/tensor/ravel.cc @@ -45,6 +45,7 @@ Examples:: .set_attr_parser(ParamParser) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data"}; } ) .set_attr("FInferShape", RavelOpShape) @@ -70,6 +71,7 @@ Examples:: .set_attr_parser(ParamParser) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data"}; } ) .set_attr("FInferShape", UnravelOpShape) diff --git a/src/operator/tensor/square_sum.cc b/src/operator/tensor/square_sum.cc index af365bae05dc..255ec5bb8032 100644 --- a/src/operator/tensor/square_sum.cc +++ b/src/operator/tensor/square_sum.cc @@ -71,6 +71,7 @@ 
MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_square_sum) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("THasDeterministicOutput", true) .set_attr("FInferStorageType", SquareSumBackwardInferStorageType) .set_attr("FComputeEx", SquareSumOpBackwardEx); diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 48c4f1664226..a2aad2c079fc 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -16,7 +16,9 @@ # under the License. import copy +import sys import os +import logging import re import json import mxnet as mx @@ -391,7 +393,6 @@ def test_children_same_name(): for c in b.get_children(): pass - def test_transpose_nullop(): for dim in range(1, 7): a = mx.sym.Variable('a') @@ -417,6 +418,91 @@ def test_gen_atomic_symbol_multiple_outputs(): atomic_sym = s._gen_atomic_symbol() +def test_eliminate_common_expr(): + if not sys.platform.startswith('linux'): + logging.info("Bypass the CSE test on non-Linux OS as setting env variables during test does not work on Windows") + return + def set_back_env_var(var_name, old_env_var): + if old_env_var is None: + os.environ.pop(var_name) + else: + os.environ[var_name] = old_env_var + + # helper function to test a single model + def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs): + inputs = sym.list_inputs() + shapes = {inp : kwargs[inp].shape for inp in inputs} + rtol = {'float16' : 1e-2, + 'float32' : 1.5e-6, + 'float64' : 1.5e-6, + } + atol = {'float16' : 1e-3, + 'float32' : 1e-7, + 'float64' : 1e-7, + } + env_var_name = 'MXNET_ELIMINATE_COMMON_EXPR' + old_env_var = os.environ.get(env_var_name, None) + try: + for dtype in ['float16', 'float32', 'float64']: + data = {inp : kwargs[inp].astype(dtype) for inp in inputs} + for grad_req in ['write', 'add']: + type_dict = {inp : dtype for inp in inputs} + os.environ[env_var_name] = '0' + orig_exec = sym.simple_bind(ctx=mx.cpu(0), grad_req=grad_req, + type_dict=type_dict, **shapes) + os.environ[env_var_name] = '1' + cse_exec = sym.simple_bind(ctx=mx.cpu(0), grad_req=grad_req, + type_dict=type_dict, **shapes) + fwd_orig = orig_exec.forward(is_train=True, **data) + out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig] + orig_exec.backward(out_grads=out_grads) + fwd_cse = cse_exec.forward(is_train=True, **data) + cse_exec.backward(out_grads=out_grads) + if check_data: + for orig, cse in zip(fwd_orig, fwd_cse): + np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(), + rtol=rtol[dtype], atol=atol[dtype]) + for orig, cse in zip(orig_exec.grad_arrays, cse_exec.grad_arrays): + if orig is None and cse is None: + continue + assert orig is not None + assert cse is not None + np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(), + rtol=rtol[dtype], atol=atol[dtype]) + orig_sym_internals = orig_exec.get_optimized_symbol().get_internals() + cse_sym_internals = cse_exec.get_optimized_symbol().get_internals() + # test that the graph has been simplified as expected + assert (len(cse_sym_internals) + expected_savings) == len(orig_sym_internals) + finally: + set_back_env_var(env_var_name, old_env_var) + + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + c = mx.sym.Variable('c') + shape = rand_shape_nd(2) + arr1 = mx.random.uniform(shape=shape) + arr2 = mx.random.uniform(shape=shape) + arr3 = mx.random.uniform(shape=shape) + + check_cse_on_symbol((a+5) + (a+5), expected_savings=1, check_data=True, a=arr1, b=arr2) + check_cse_on_symbol((a+1) + (a+2), expected_savings=0, 
check_data=True, a=arr1, b=arr2) + check_cse_on_symbol((1+a) + (a+1), expected_savings=1, check_data=True, a=arr1, b=arr2) + check_cse_on_symbol((a+b) + (a+b), expected_savings=1, check_data=True, a=arr1, b=arr2) + check_cse_on_symbol(((a+b)+c) +((a+b)+c), expected_savings=2, check_data=True, + a=arr1, b=arr2, c=arr3) + d = a + 1 + + # a*d node gets eliminated, but then a copy is inserted to isolate the outputs, so no net gain. + check_cse_on_symbol(mx.sym.Group([a*d, a*d]), expected_savings=0, check_data=True, a=arr1) + + # a*d node gets eliminated, then the duplicated add-of-b, but then a copy is added for net of 1. + check_cse_on_symbol(mx.sym.Group([a*d+b, a*d+b]), expected_savings=1, check_data=True, + a=arr1, b=arr2) + + # dropout uses a resource that precludes any optimization + check_cse_on_symbol(mx.sym.Dropout(a) + + mx.sym.Dropout(a), expected_savings=0, check_data=False, a=arr1) + def test_load_save_symbol(): batch_size = 10 num_hdidden = 128
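For readers of the patch, here is a minimal usage sketch of the flow exercised by the test above: toggling `MXNET_ELIMINATE_COMMON_EXPR` around `simple_bind` and inspecting the result through the new `Executor.get_optimized_symbol()`. It assumes a build that already contains this change; the node counts in the comments are illustrative, not guaranteed.

```python
import os
import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
sym = (a + b) + (a + b)   # the two (a + b) sub-expressions are identical

def bind_with_cse(flag):
    # The pass runs while the executor builds its graph, so the variable
    # must be set before binding.
    os.environ['MXNET_ELIMINATE_COMMON_EXPR'] = flag
    return sym.simple_bind(ctx=mx.cpu(), a=(2, 3), b=(2, 3))

plain = bind_with_cse('0')
cse = bind_with_cse('1')

# With the pass enabled, one of the duplicated elemwise_add nodes disappears
# from the optimized graph (illustrative counts: 5 vs. 4 internal outputs).
print(len(plain.get_optimized_symbol().get_internals()))
print(len(cse.get_optimized_symbol().get_internals()))
```

Because the variable is read while the executor's graph is constructed, it must be set before binding, which is why the test saves and restores the previous value around each `simple_bind` call.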