This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNet AMP (automatic mixed precision) #14173

Merged: 74 commits merged into apache:master on May 21, 2019

Conversation

@ptrendx (Member) commented on Feb 15, 2019

Description

This is a Work in Progress PR adding AMP (automatic mixed precision) support to MXNet, similar to the PyTorch version found in https://github.com/NVIDIA/apex.

This PR relies on multiple other PRs and bug fixes, listed in the Comments section.

The dynamic loss scaling part was done by @Caenorst (commits were squashed for easier rebasing).

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, the expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • An auditor that automatically enables/disables running operations in FP16. It is implemented by patching MXNet functions in mxnet.symbol and mxnet.ndarray to insert casts to FP16/FP32 where necessary.
  • Operators amp_cast and amp_multicast that handle casting between FP16 and FP32 when necessary and leave other types unchanged. They are optimized to be no-ops when the input is already of the proper type.
  • Dynamic loss scaling and supporting operators that check gradients for infs/NaNs and skip the update step if such a value is encountered (a rough usage sketch follows this list).
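
As a rough illustration of how these pieces are meant to fit together, here is a usage sketch in the style of the contrib AMP tutorial added later in this PR; the amp.init / amp.init_trainer / amp.scale_loss names and the exact signatures should be treated as approximate rather than authoritative.

import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

# Patch mxnet.symbol/mxnet.ndarray functions so that whitelisted ops run in
# FP16 and blacklisted ops receive FP32 inputs.
amp.init()

net = gluon.model_zoo.vision.resnet50_v1(pretrained=True, ctx=mx.gpu(0))
net.hybridize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
# Attach the dynamic loss-scaling state to the trainer.
amp.init_trainer(trainer)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(8, 3, 224, 224), ctx=mx.gpu(0))
label = mx.nd.zeros((8,), ctx=mx.gpu(0))

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    # Scale the loss before backward; gradients are later checked for
    # infs/NaNs and the update is skipped if any are found.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(data.shape[0])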

Comments

FYI @eric-haibin-lin @szha

@ptrendx requested a review from szha as a code owner on February 15, 2019 03:55
.gitmodules Outdated
@@ -25,7 +25,7 @@
url = https://github.com/dmlc/cub
[submodule "3rdparty/tvm"]
Member:

If we need to upgrade to the latest commit of tvm (which depends on its own version of dmlc-core), is it necessary to upgrade the dmlc-core submodule in mxnet?

Member Author (@ptrendx):

a. Any nnvm header that is included in an MXNet .cc file will use MXNet's dmlc-core.
b. I'm not sure whether the NNVM build uses its own dmlc-core - it is possible that the environment is set up so that all components use the same (MXNet's) dmlc-core. Otherwise you could end up with binary incompatibilities (e.g. if dmlc-core changes some struct that is passed around between NNVM and MXNet, you don't want it to be interpreted differently).

Member:

tvm now comes with a NOTICE file, so we'd need to add that to mxnet's NOTICE too, per the Apache license.

@ankkhedia (Contributor):

@ptrendx Thanks for the contribution!

@mxnet-label-bot add [pr-work-in-progress ]

@marcoabreu added the pr-work-in-progress (PR is still work in progress) label on Feb 15, 2019
python/mxnet/amp/amp.py (two outdated review comments, resolved)
# If the conditional argument is absent or does not take one of the listed
# values, call the original function unchanged; otherwise cast the positional
# arguments to the target dtype.
if (cond_arg[0] not in kwargs or
        kwargs[cond_arg[0]] not in cond_arg[1]):
    return f(*args, **kwargs)
new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
Contributor:

If an fp16 output is used for multiple fp32 operators, can we cast the fp16 only once?

Member Author (@ptrendx):

That would require a graph pass approach instead of simple function substitution.

Contributor:

This will degrade the performance a lot, so I suggest using the graph pass.

Member Author (@ptrendx):

Adding a graph pass may be the next step for performance tuning, but it is definitely not necessary for the feature to be useful. I tested all the models from GluonCV and in none of them did I see a need for it (the unnecessary casts typically take <1% of the total time).

Contributor:

Sounds good :) Could you share your data and the plan, as per my comments below?

@pengzhao-intel (Contributor):

@ptrendx @eric-haibin-lin could you share the total picture (or methodology/SW stack) of the AMP integration? I'd like to understand it at a high level before going into the details.

if not _amp_initialized:
    _amp_initialized = True
    logging.info("Using AMP")
    target_dtype = np.dtype(target_dtype)
Member:

assert that target_dtype is float16

Member Author (@ptrendx):

Ok
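
A minimal sketch of the requested check (the exact message and its placement inside amp.init are assumptions):

assert target_dtype in ['float16', np.float16], \
       "AMP currently supports only float16 as the target dtype"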

"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""

# Functions that should be cast to lower precision
TARGET_DTYPE_FUNCS = [
Member:

This can still be called FP16_FUNCS and FP16_FP32_FUNCS, right? If bfloat16 support is added in the future, additional lists would be added here.

Member Author (@ptrendx):

In the previous comment you said "change the lists to have target_dtype_list, target_dtype_fp32_list", which is why I made that change. I can revert it - either way is fine by me.

Member:

Apologies for not being clear. I meant <target_dtype>_fp32_list

Member Author (@ptrendx):

ok, will change
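
For context, a sketch of how the per-dtype lists ended up being organized (the FP16_FUNCS / FP16_FP32_FUNCS / FP32_FUNCS names match the "Moving back to FP16_FUNCS and FP16_FP32_FUNCS" commit below; the entries shown are purely illustrative):

# Ops that are safe and beneficial to run in FP16.
FP16_FUNCS = [
    'Convolution',
    'FullyConnected',
    ]

# Ops that can run in either precision and are left in whatever type their
# inputs already have.
FP16_FP32_FUNCS = [
    'Activation',
    'BatchNorm',
    ]

# Ops that are forced back to FP32 for numerical safety.
FP32_FUNCS = [
    # numerically sensitive ops are pinned to FP32 here
    ]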

aux = sym.list_auxiliary_states()
# Cast the generated inputs to the target dtype, but leave auxiliary states
# (e.g. BatchNorm running statistics) untouched.
inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
                  if x.name not in aux else x, inputs))
atomic_sym = sym._gen_atomic_symbol()
Member:

why is this needed?

Member Author (@ptrendx):

Basically, this allows us to cast inputs to the op that were not specified by the user (e.g. if you create a convolution via the symbolic API, you don't need to pass weights and biases to it; MXNet will generate them). So what we do here is create a symbol using the original function (which creates all the other children), then take those children and cast them. However, neither MXNet nor NNVM allows manipulating a symbol's children after it has been created. So we create a new, atomic symbol (atomic meaning it does not have any children set yet, but has all the same params as the symbol used to create it) and populate that symbol with the cast inputs.
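
A condensed sketch of that flow (the wrapper shape is approximated; _cast_symbol_NDArray and _gen_atomic_symbol come from this PR, and the final composition call is an assumption):

def _wrap_symbol_fn(f, target_dtype):
    def wrapper(*args, **kwargs):
        # Create the symbol normally; MXNet generates the hidden inputs
        # (weight, bias, ...) that the user did not pass explicitly.
        sym = f(*args, **kwargs)
        inputs = sym.get_children()
        aux = sym.list_auxiliary_states()
        # Cast every generated input except the auxiliary states.
        inputs = [x if x.name in aux else _cast_symbol_NDArray(x, target_dtype)
                  for x in inputs]
        # Children of an existing symbol cannot be replaced, so rebuild the op
        # as an atomic symbol (same params, no children) and compose it with
        # the cast inputs.
        atomic_sym = sym._gen_atomic_symbol()
        return atomic_sym(*inputs)
    return wrapper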

Member:

When pretrained models are loaded under amp.init, will it silently fail?

Member Author (@ptrendx):

Pretrained as in symbol? That will not do anything (it needs your PR).
Pretrained as in load_parameters? That will work as expected (I was using it for all my experiments with GluonCV, where e.g. SSD uses a pretrained RN50 backbone).

@anirudh2290 (Member) left a comment:

Overall this change LGTM. This will be really useful to our users. Thanks a lot for your effort @ptrendx!

@ptrendx changed the title from "[WIP] MXNet AMP (automatic mixed precision)" to "MXNet AMP (automatic mixed precision)" on May 17, 2019
@eric-haibin-lin (Member) left a comment:

Can we have unit tests for loss scaler, multi_cast and isfinite ops to make sure they're not broken in the future?
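
A rough sketch of what such tests could look like (the operator names amp_cast, amp_multicast and all_finite come from this PR; the exact signatures, namespaces and return conventions used below are assumptions):

import numpy as np
import mxnet as mx

def test_amp_cast_is_noop_on_matching_dtype():
    x = mx.nd.ones((2, 2), dtype='float16')
    y = mx.nd.amp_cast(x, dtype='float16')  # should leave dtype and values unchanged
    assert y.dtype == np.float16
    assert np.array_equal(x.asnumpy(), y.asnumpy())

def test_amp_multicast_casts_to_widest_type():
    a = mx.nd.ones((2, 2), dtype='float16')
    b = mx.nd.ones((2, 2), dtype='float32')
    out_a, out_b = mx.nd.amp_multicast(a, b, num_outputs=2)
    # Both outputs should come back in the widest input type (FP32 here).
    assert out_a.dtype == np.float32 and out_b.dtype == np.float32

def test_all_finite_flags_nan():
    good = mx.nd.ones((4,))
    bad = mx.nd.array([1.0, float('nan'), 2.0, 3.0])
    # Assuming the op returns 1 when every element is finite and 0 otherwise.
    assert mx.nd.all_finite(good).asscalar() == 1
    assert mx.nd.all_finite(bad).asscalar() == 0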

python/mxnet/model.py (two review comments, resolved)
python/mxnet/symbol/symbol.py (review comment resolved)
@anirudh2290 merged commit 5bc08ce into apache:master on May 21, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* Beginning of AMP

* Optimize noop cast

* More operations added

* Backward cast

* Adding AMPCast and AMPMultiCast

* Fix some of lint

* Changed symbol wrapper to handle hidden inputs
Added PoC of dynamic loss scaling

* Moved back to dmlc/tvm repo

* fix counter reset to increase loss scale every 2k iterations

* Fix indentation

* Add contrib from symbol and ndarray to symbol list

* Adding where to widest type cast

* Do not cast in imperative mode on CPU context

* Update dmlc-core to fix unittests

* Fix wrapper metadata, fix self handling

* Blacklist sync batchnorm (since its implementation is FP32 only)

* Fix lint

* Enable losses to be tuple

* Get rid of AMP handle

* Add scaling to Output functions

* Fix pylint

* Update dmlc-core

* Changing prints in AMP to logging.info

* NNVM -> MXNet for FInferShape

* Bring the inplaceidentity fix to copied pass from NNVM

* Added tutorial for AMP

* Making Windows compiler happy

* Fixes to tutorial

* More fixes

* Fix lint

* Fix

* Add amp/index.md to whitelist for tutorial tests

* Whitelisting cuDNN RNN

* Manual unscale

* _internal functions wrapping

* Make SymbolFunctor from Symbol

* Fix the type infer function of AMP multicast

* Added ability to override casting lists

* Making clang-tidy and pylint happy

* More cleaning

* Making clang-tidy really happy

* remove amp_cast and amp_multicast before saving the model

* Changes from review

* Add RemoveAmpCast in a separate c_api function, add the option in symbol.save

* add remove_amp_cast option (True by default) to everyway of saving symbol

* Fix

* First stab at adding the gray list

* More ops added

* Adding the rest of the functions

* Improvements to AMP test

* Changing of names and changing wrapping

* Moving to contrib

* Modifying tutorial for contrib AMP

* Removing non existent functions

* Fix import in test

* Fix lint

* Added new functions

* Added assert

* Fix the unknown ndim in PlanMemory pass

* Moving back to FP16_FUNCS and FP16_FP32_FUNCS

* Removing unnecessary ops

* Adding ops that exist only in some build configurations and removing
tests checking that AMP lists contain only existing ops

* Removing warning when not every function was found during AMP init
because of functions being available only in specific configurations

* Add tests and doc

* Fix the CPU version of all_finite

* Adding test cases for all_finite operator

* Add new operators

* Fix
Labels: pr-work-in-progress (PR is still work in progress)

10 participants