This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Eliminate common expressions #15657

Merged: @ptrendx merged 21 commits into apache:master on Nov 1, 2019

Conversation

@ptrendx (Member) commented Jul 25, 2019

Description

This PR introduces a graph pass that eliminates redundant computation of common subexpressions in the graph.

For example, consider the graph created from the following Python code:

a = mx.sym.Variable('a')
out = (a + 5) * (a + 5)

This graph will compute a+5 twice, which is wasteful. After the pass introduced in this PR, the actual executed graph will be equivalent to

a = mx.sym.Variable('a')
b = a + 5
out = b * b

which computes a + 5 only once.
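
For reference, here is a rough sketch of how the pass can be observed from Python. The MXNET_ELIMINATE_COMMON_EXPR environment variable and get_optimized_symbol come from this PR's commits; the assumption that get_optimized_symbol is available on the bound executor, as well as the shape and the final check, are illustrative only:

import os
import mxnet as mx

# The pass is enabled by default; set the variable to 0 to compare against
# the unoptimized graph.
os.environ['MXNET_ELIMINATE_COMMON_EXPR'] = '1'

a = mx.sym.Variable('a')
out = (a + 5) * (a + 5)

# Binding the symbol runs the graph passes; shape and context are arbitrary.
exe = out.simple_bind(ctx=mx.cpu(), a=(3, 4))

# With the pass enabled, the optimized graph should contain only one
# scalar-add node instead of two.
optimized = exe.get_optimized_symbol()
print(optimized.get_internals().list_outputs())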

@eric-haibin-lin FYI

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Graph pass eliminating common subexpressions (eliminate_common_expr_pass.cc), with unit tests
  • Environment variable MXNET_ELIMINATE_COMMON_EXPR (enabled by default) to turn the pass on or off, and get_optimized_symbol exposed to inspect the optimized graph
  • THasDeterministicOutput attribute added to multiple ops

Comments

@abhinavs95 (Contributor) commented

@mxnet-label-bot add [Backend, pr-awaiting-review]

@marcoabreu added the Backend (Issues related to the backend of MXNet) and pr-awaiting-review (PR is waiting for code review) labels on Jul 25, 2019
@ptrendx requested a review from szha as a code owner on September 11, 2019
@DickJC123 (Contributor) commented

This will be an awesome addition. Some things to consider in polishing it:

  1. The crux is in the definition of functionally equal nodes. Do you think your method bool NodeEqual(const Node * n, const Node * m) belongs in the CSE code, or in the Node class?
  2. It's probably best to be conservative with NodeEqual, but by comparing attrs.dict you will miss some equivalent nodes. To illustrate my point, I instrumented the reduce operator to print out the dict:
>>> x = mx.nd.array([1,2,3,4])
>>> mx.nd.sum(x)
[21:32:38] src/operator/tensor/./broadcast_reduce_op.h:621: node attr dict = {}
[10.] <NDArray 1 @cpu(0)>
>>> mx.nd.sum(x, axis=())
[21:32:53] src/operator/tensor/./broadcast_reduce_op.h:621: node attr dict = {{axis,()}}
[10.] <NDArray 1 @cpu(0)>
>>> mx.nd.sum(x, axis=0)
[21:33:07] src/operator/tensor/./broadcast_reduce_op.h:621: node attr dict = {{axis,0}}
[10.] <NDArray 1 @cpu(0)>
>>> mx.nd.sum(x, axis=(0,))
[21:33:39] src/operator/tensor/./broadcast_reduce_op.h:621: node attr dict = {{axis,(0,)}}
[10.] <NDArray 1 @cpu(0)>

All of these sum operators are functionally equivalent, but none would compare as equal under an approach that compares attrs.dict (the map<string, string> built from the Python-level arguments); see the sketch at the end of this comment.
Any chance you could move the equality check to comparing the parameter struct?
  3. I see you consider operators that have resources as never equal. Might you maintain a whitelist of resources that are OK for equal nodes to have, e.g. the commonly used tempspace resource?
  4. I'd prefer to see a test for each of the reasons you deny node equality (e.g. having mutable inputs). You could either count the nodes to prove that CSE did not happen, or test against a golden copy, since wrongly performing CSE would break functionality. I see you have a test for a case where CSE should happen, but it only checks the reduced node count without testing functionality.
  5. Since CSE combines output nodes, those output nodes should not be mutable inputs of downstream nodes, e.g.:

x = Variable('x')
xcopy1 = x.copy()
xcopy2 = x.copy()
y = SomeOp(...,some_mutable_input=xcopy1,...)
z = SomeOp(...,some_mutable_input=xcopy2,...)

If you combine the equal nodes xcopy1 and xcopy2, then the y and z nodes will start seeing each other's effects, which they should not as originally coded.
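
To make the attrs.dict point above concrete, here is a small self-contained Python sketch (not the PR's actual C++ NodeEqual) that applies the same "same op, identical attrs dict" rule to the serialized graph; all four sum variants are functionally equivalent, yet none of them compares equal to the first:

import json
import mxnet as mx

def node_equal(n, m):
    # Conservative structural equality: same operator and byte-identical
    # attribute dict. sum with axis omitted, axis=(), axis=0 and axis=(0,)
    # all serialize with different dicts, so they compare as different nodes.
    return n['op'] == m['op'] and n.get('attrs', {}) == m.get('attrs', {})

x = mx.sym.Variable('x')
variants = [mx.sym.sum(x),
            mx.sym.sum(x, axis=()),
            mx.sym.sum(x, axis=0),
            mx.sym.sum(x, axis=(0,))]
# The sum node is the last entry in each symbol's serialized node list.
nodes = [json.loads(v.tojson())['nodes'][-1] for v in variants]
print([node_equal(nodes[0], n) for n in nodes[1:]])  # [False, False, False]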

@ptrendx (Member, Author) commented Sep 12, 2019

Thanks @DickJC123! Your thoroughness is amazing as always :-). I will think about the other points you made, but point 5 is especially interesting. I did not think about it before and you are right that this is a potential problem. I will introduce an additional check for that.

@DickJC123 (Contributor) commented Sep 13, 2019

Another area to spend some extra thought on involves the primary outputs and the gradients of primary inputs. If you combined two output NDArrays via CSE, would there be any problem with someone wanting to index through the outputs? Ditto for input grads. If a CSE-optimized model generated NDArrays for another model (and that model had mutable inputs), would functionality be preserved? I think a downstream model gets a copy of the NDArray as the input, so it should work, but do we ever want to optimize away that copy?

CSE may eliminate downstream in-place options for node outputs that have been combined. Not sure if there's a case where using CSE is a net perf loss, but I doubt it.

And finally, here's a fairly obscure case where the operator state may not be contained entirely in the node attributes (we should probably outlaw this usage, since it assumes no caching of the env var lookup):

import os

sym = ...
os.environ['MXNET_MY_OP_BEHAVIOR'] = 'X'
sym2 = MyOp(data=sym, ...)                       # op behavior affected by captured env var setting
os.environ['MXNET_MY_OP_BEHAVIOR'] = 'Y'
sym3 = MyOp(data=sym, ...)                       # op behavior affected by captured env var setting

Perhaps shifting from the node attr dict to the param struct would catch this case.

@ptrendx (Member, Author) commented Sep 20, 2019

@DickJC123 Ok, answering your questions:

  1. This definition of Node equality is quite tied to what I want to do with it, so I believe that, at least right now, it belongs in the CSE code.
  2. While I agree it would be beneficial, it would also be pretty hard to do - the parameter structure is kept without its type information, as a shared_ptr<dmlc::any>. Comparing memory (which could be done if any exposed the size of the held type) is not really feasible either, since the dict is actually part of the Parameter structure, so the bytes would differ.
  3. This is the purpose of the THasDeterministicOutput property.
  4. I will think about that.

@DickJC123 (Contributor) commented

I believe the case of one output driving multiple symbol outputs is handled strangely, though this is not an important use case. The issue occurs when an operator drives 3 or more symbol outputs. Here's the progression of model compute graphs:

foo ----+----> out0
        |
        +----> out1
        |
        +----> out2

copy operators inserted:

foo ----+----> out0
        |
        +-- copy --> out1
        |
        +-- copy --> out2

More CSE applied:

foo ----+----> out0
        |
        +-- copy --+----> out1
                   |
                   +----> out2

copy operators inserted:

foo ----+----> out0
        |
        +-- copy --+----> out1
                   |
                   +-- copy --> out2

So N outputs driven by the same node end up being driven by N-1 serialized copies, not N-1 parallel copies.

What is the use case that requires these copies to be inserted?
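
For reference, the "one node driving several symbol outputs" situation from the diagrams can be constructed directly in Python, assuming mx.sym.Group accepts the same symbol more than once; the names here mirror the diagrams and are illustrative:

import mxnet as mx

a = mx.sym.Variable('a')
foo = a + 5
# The same node listed three times as symbol outputs: the out0/out1/out2
# case from the diagrams above.
out = mx.sym.Group([foo, foo, foo])
print(out.list_outputs())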

@DickJC123 (Contributor) commented

The EliminateCommonExpr pass is performed first thing, before the backward nodes have been inserted and so before ctx info (i.e. dev_mask) has been added to the graph. I think this is the correct place: doing CSE with backward nodes present would be a nightmare.

I'm not fond of the THasDeterministicOutput flag; it seems like just one more thing the user has to understand and set correctly when writing a new operator. I was hoping this property could instead be inferred from the op's registered resources. Life would be simple if the resources were determined only by the FResourceRequest(attrs) callback. The problem is that there is also the FResourceRequestEx(attrs, dev_mask, dispatch_mode) callback, which takes a dev_mask and dispatch_mode. As I point out above, CSE is performed before dev_mask is available from the graph. The only way I can see to do what I'm suggesting would be to infer the possible settings of dev_mask and dispatch_mode from the registration of the other callbacks. Some possible pseudo-code for this:

set<resources> all_possible_resources = {}
if (op_has_registered("FResourceRequestEx")) {
  auto f_resource_req_ex = < look up op's FResourceRequestEx >
  if (op_has_registered("FCompute<cpu>"))
    all_possible_resources += f_resource_req_ex(attrs, kCpu, DispatchMode::kFCompute)
  if (op_has_registered("FComputeEx<cpu>"))
    all_possible_resources += f_resource_req_ex(attrs, kCpu, DispatchMode::kFComputeEx)
  if (op_has_registered("FCompute<gpu>"))
    all_possible_resources += f_resource_req_ex(attrs, kGpu, DispatchMode::kFCompute)
  if (op_has_registered("FComputeEx<gpu>"))
    all_possible_resources += f_resource_req_ex(attrs, kGpu, DispatchMode::kFComputeEx)
} else if (op_has_registered("FResourceRequest")) {
  auto f_resource_req = < look up op's FResourceRequest >
  all_possible_resources += f_resource_req(attrs)
}
bool inferred_cse_not_possible = LookForBadResources(all_possible_resources)

Regarding the actual name 'THasDeterministicOutput': might this be confused with the CUDA notion of deterministic algorithms? The Convolution op would be tagged with THasDeterministicOutput=1, but would not be truly deterministic if a convolution algorithm that uses atomics is selected.

@DickJC123 (Contributor) left a review comment

I've supplied commits to @ptrendx that bring this PR to my liking. LGTM.

@ptrendx merged commit 1aa1b5a into apache:master on Nov 1, 2019
apeforest pushed a commit that referenced this pull request Nov 6, 2019
* Eliminate common expressions from a graph

* Guarding against optimizing out stateful ops and ops that require
resource

* Fix lint

* Added THasDeterministicOutput to multiple ops

* DDebug eliminate common expr

* Added test

* Expose get_optimized_symbol

* Fix

* Fix 2

* Add doc to the Python call

* Add env var MXNET_ELIMINATE_COMMON_EXPR, default true

* Add comments, improve readability of eliminate_common_expr_pass.cc

* Expand testing

* Lower priority of THasDeterministicOutput attr for equal Node test

* Change mx.gpu() to mx.cpu() in tests

* Skip CSE test on Windows (as env variable setting during test does not work there)

* Add missing import sys

* Add missing import logging
yajiedesign pushed a commit to yajiedesign/mxnet that referenced this pull request Nov 6, 2019
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request Nov 14, 2019
ArmageddonKnight pushed a commit to UofT-EcoSystem/incubator-mxnet that referenced this pull request Feb 1, 2020
* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests

* Fix download cmd in runtime_functions

* Add CI changes

* Add stage

Fix indentation

* Fix lint

* Change to DEFAULT for C API

* Fix mxnet_unit_tests path

* export correct LD_LIBRARY_PATH

* Add cpp include dirs

* Build test with USE_CPP_PACKAGE

* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests

* Fix download cmd in runtime_functions

* Merge

* change mkldnn lib name

* Add static_alloc, static_Shape support

* Address review comments

* Make GetCachedOpThreadSafeState similar to cached_op

* Address review comments: comments for locking strategy

* multithreaded inference tutorial

* [Estimator] handle composite metrics in estimator (apache#16676)

* handle composite metrics in estimator

* fix composite metric case in handlers

* remove unused import

* [Estimator] refactor estimator to allow overriding evaluate/fit of a batch (apache#16678)

* refactor estimator to allow overriding evaluate/fit of a batch

* add doc to explain call structure and how to override

* fix and doc

* Pointwise fusion for GPU (apache#15167)

* Beginning of RTC of pointwise ops

* Code generation from the given JSON

* add initial simple_partition_pass and use it for pointwise fusion

* fix the fusion, use a symbol.Copy() at the beginning of binding function, use the name of input nodes in the cuda code

* Fixes

* Adding support for attribute inference for backward nodes when fusing

* keep proper input ordering for fused Op

* instantiate the indexed_graph before starting the subgraph replacement, return a new graph to reset the indexed_graph

* Fuse backward

* fix ordering of subgraph node inputs using subgraph topological ordering instead of main graph topological ordering, add tvm.patch

* excluse forward node fusion during the fusion of the nodes in the backward graph

* Dealing with fused backward nodes inferattr

* use subgraph.indexed_graph() instead of main for _FusedOpHelper nodes node_id, invert control_deps loop to modify topology of subgraph before calling its indexed_graph(), check that all node of the first DFSVisit are actually in the subgraph

* Adding support for other reqs in codegen

* Fix

* Cleaning

* Change the TVM submodule

* More cleaning

* Making linter happy

* Do fusion only if default context is GPU

* Fixes for tests
Add powerscalar and rpowerscalar, fix return type of zero and one
Cleaning, fixing lint
Go back to proper TVM submodule

* Fix the TVM commit

* Fix lint

* Guard fusion with MXNET_USE_CUDA

* Fix

* Fix clang-tidy

* Add erf and erfinv backward

* Gluon support for fusion

* Cleaning

* Cleaning and allow shape/type change in FusedOp

* Fixing Gluon bugs

* Fixing after rebase

* Fixing race condition and guarding against races when using NVRTC

* Cleaning and renaming FusedOp to _FusedOp

* Going easy on Windows compiler

* Disable fusion on Windows for now

* Refactor InferAttr and InferShapeAttr

* Added slice and half2 support to FusedOp

* Fix lint errors

* Added multiple types support for vector loading/storing

* add slice fusion when it's at the beginning of subgraphs

* Removed constant ndim assumption in fused op

* Fix memory alignment issue in slice for FusedOp

* Fixes

* Fix lint errors

* Do not include cuda_fp16.h

* Refactor fused op op lists

* Make linter happy

* Changes from review

* Fixes after rebase

* Expand FusedOp support for slice

* Fix for fp16 _zeros and _ones

* Fix

* Moving aux functions to unnamed namespace and detail namespace -> fusion
namespace

* Disabling fusion if it alters topological order of inputs

* Print code only when env variable is set

* Fix

* Fix lint and 2 tests that specify the same names for multiple inputs

* Fixes from review and disabling fusion of slice with non-default step

* Add amp_cast to fusion, fixes

* Add amp_multicast and its backward to the list of support ops

* Apply wording suggestions from code review

Co-Authored-By: Aaron Markham <[email protected]>

* Apply wording suggestions from code review

Co-Authored-By: Aaron Markham <[email protected]>

* Make clearer comment

* Adding punctuation and capitalization to \brief descriptions

* Fix

* Fix

* Add backward_cast to fusion

* Adding unittests for fusion. Fix for erfinv_grad

* Adding slice ops and add_n to tests

* Fixes from review

* Setting inplace option

* Fix lint

* Storing double in half

* Retrigger CI

* Slight relaxing of the relative tolerance in the test

* Move the env variable check to the end

* Fix a race condition between InferShape and scheduled Forward

* Fix flakey test_fusion test involving fp32 erfinv op.

* Fix from review

* Added broadcast_like and slice_like to fused op

* Minor fix and cleanup

* Added negative axis support in slice_axis, temporarily disabled fusion of slice_like and broadcast_like

* Added axes support to slice_like

* Added axis support to broadcast_like

* Add fast_load_slice function to fused op code

* Added runtime switch for choosing fast and slow slice kernel

* Fix lint and warning

* Going easy on Windows compiler (again)

* Fix slice_like

* Debug broadcast_like fusion

* Fix lint

* Fix lint

* Trigger CI

* Get rid of the initializer list

* Fix backward calls with different gradient type

* avoid cycle when adding node specific for inputs of subgraph for pointwise fusion

* Fix lint

* Add namespace to the fusion implementations

* Set launch bounds on the fused kernel

* Fix NumPy tests

* Test showcasing an issue fixed in PR apache#16553

* Cast scalarts to FP32 and perform (a*1.0/b) instead of (a/b)

Fix lint errors

Fix lint

* Fix a bug in cycle detection for inputs only op in pointwise fusion

* Add comments to simple_partition_pass.h file

* fix install dir (apache#16690)

* [numpy] add numpy operator : append (apache#16564)

* add operator : append ; fix op concatenate when axis = None

* pylint disable

remove mistake

disable pylint

* Initializer.__eq__ (apache#16680)

* fix binary dependencies in CD and nightly (apache#16693)

* [MKL-DNN] Add mxnet mkldnn cmake tutorial (apache#16688)

* add mxnet mkldnn cmake instruction

* imporve doc

* OMP->OpenMP

* Revert "[MKLDNN]Fix reorder2default (apache#16602)" (apache#16697)

This reverts commit dd4eaf5.

* [Estimator] refactor estimator and clarify docs (apache#16694)

* refactor estimator and clarify docs

* fix info message and test

* clean up after releasing logging handler

* Eliminate common expressions (apache#15657)

* Eliminate common expressions from a graph

* Guarding against optimizing out stateful ops and ops that require
resource

* Fix lint

* Added THasDeterministicOutput to multiple ops

* DDebug eliminate common expr

* Added test

* Expose get_optimized_symbol

* Fix

* Fix 2

* Add doc to the Python call

* Add env var MXNET_ELIMINATE_COMMON_EXPR, default true

* Add comments, improve readability of eliminate_common_expr_pass.cc

* Expand testing

* Lower priority of THasDeterministicOutput attr for equal Node test

* Change mx.gpu() to mx.cpu() in tests

* Skip CSE test on Windows (as env variable setting during test does not work there)

* Add missing import sys

* Add missing import logging

* Backport of apache#16711, apache#16737, apache#16408 to 1.6 branch (apache#16763)

* support mixed-precision true_divide (apache#16711)

* [MKLDNN] use dim_t instead of int in slice/transpose operators (apache#16737)

* use dim_t instead of int

* fix same issue in pooling

* rebase code

* trigger CI

* Add MXNet Ops for fast multihead attention (apache#16408)

* add MXNet Ops for fast multihead attention

* add cutlass as 3rdparty dependency

* add cutlass to compilation flags

* remove all cutlass stuff

* add better error message and description and remove cutlass from compilation flags

* change credit for the approach since the code have changed

* fix typos

* correct another typo

* Add all the cuda/cublas helper functions

* remove tests using kAddTo

* only use cublasStridedBatchedGemm if CUDA >= 9.1

* add equivalent mxnet code in description of mha ops

* remove a wrong copy-paste

* add _contrib for namespace and add GPU only on description

* add warning in bwd_ignore_zero_init description, also test with fp32

* add error return if bwd_ignore_zero_init is used without MXNET_EXEC_ENABLE_ADDTO

* remove std::move for clang

* remove bwd_ignore_zero_init flag

* remove bwd_ignore_zero_init in test_operator_gpu.py

* fix typo

* fix another typo

* Removed unrelated test

* Add example and documentation for multi threaded inference

* Add LICENSE

* Add get_model.py

* Add license for README

* Refactor cached op and cached op threadsafe

* Add limitation

* Add tests for naive engine

* Add latest test changes

* Thread Safety tests in NaiveEngine mode

* Thread Safety tests update

* Update thread safety tests, add unsupported use cases

* Changes to doc and refactor

* Fix todo owner, indentation and mx_float->float

* Refactor cached op code, remove num_threads arg from example

* Fix lint

* Fix warning

* Add back cython, required for unix-gpu build

* Fix for windows

* Add bulking support for thread safe cached op version

* Add support for subgraph testing

* import mxnet before calling get_backend_symbol

* Fix symbol json name

* Refactor DynamicForward

* Add comments

* Add DMLC_ATTRIBUTE_UNUSED

* Fix use_naive_run issue

* Fix lint

* Revert unittest_cpp to old test since it doesnt test thread safety

* Fix doc

Co-authored-by: Sheng Zha <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Tao Lv <[email protected]>
Co-authored-by: JiangZhaoh <[email protected]>
Co-authored-by: Leonard Lausen <[email protected]>
Co-authored-by: Xinyu Chen <[email protected]>
Co-authored-by: Zhennan Qin <[email protected]>
zheyuye pushed a commit to zheyuye/incubator-mxnet that referenced this pull request Feb 19, 2020
rondogency pushed a commit to rondogency/incubator-mxnet that referenced this pull request Jul 2, 2020