MXNet AMP (automatic mixed precision) #14173
Conversation
.gitmodules
Outdated
@@ -25,7 +25,7 @@
url = https://github.com/dmlc/cub
[submodule "3rdparty/tvm"]
If we need to upgrade to the latest commit of tvm (which depends on its own version of dmlc-core), is it also necessary to upgrade the dmlc-core submodule in mxnet?
a. any nnvm header that is included in an MXNet .cc file will use MXNet's dmlc-core
b. I'm not sure whether the build of NNVM is done using its own dmlc-core - it is possible that the environment is set up so all components use the same (MXNet's) dmlc-core. Otherwise you could end up with binary incompatibilities (e.g. if dmlc-core changes some struct that is passed around between NNVM and MXNet, you don't want it to be interpreted differently).
tvm now comes with a NOTICE file, so we'd need to add that to mxnet's NOTICE too, per the Apache license.
@ptrendx Thanks for the contribution! @mxnet-label-bot add [pr-work-in-progress]
python/mxnet/amp/amp.py
Outdated
if (cond_arg[0] not in kwargs or
        kwargs[cond_arg[0]] not in cond_arg[1]):
    return f(*args, **kwargs)
new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
If an fp16 output is used by multiple fp32 operators, can we cast the fp16 tensor only once?
That would require a graph pass approach instead of simple function substitution.
This will degrade performance a lot, so I suggest using a graph pass.
Adding a graph pass may be the next step for performance tuning, but it is definitely not necessary for the feature to be useful. I tested all the models from GluonCV and in none of them did I see a need for it (the unnecessary casts typically take <1% of the total time).
Sounds good :) Could you share your data and the plan, per my comments below?
@ptrendx @eric-haibin-lin could you share the overall picture (or methodology/SW stack) of the AMP integration? I'd like to understand it at a high level before going into the details.
if not _amp_initialized:
    _amp_initialized = True
    logging.info("Using AMP")
    target_dtype = np.dtype(target_dtype)
assert that target_dtype is float16
Ok
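For reference, a minimal sketch of the check being discussed; _check_target_dtype is a hypothetical helper name, not the actual code from this PR:

import numpy as np

def _check_target_dtype(target_dtype):
    # Hypothetical helper: AMP currently supports only float16 as the
    # target dtype, so reject anything else up front.
    target_dtype = np.dtype(target_dtype)
    assert target_dtype == np.float16, \
        "AMP target_dtype must be float16, got %s" % target_dtype
    return target_dtype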
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API.""" | ||
|
||
# Functions that should be cast to lower precision | ||
TARGET_DTYPE_FUNCS = [ |
This can still be called FP16_FUNCS and FP16_FP32_FUNCS, right? If bfloat16 support is added in the future, additional lists will be added here.
In the previous comment you said "change the lists to have target_dtype_list, target_dtype_fp32_list." and that is why I made that change. I can revert it - either way is fine with me.
Apologies for not being clear. I meant <target_dtype>_fp32_list
ok, will change
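For context, a toy illustration of the two lists under discussion, using the final FP16_FUNCS / FP16_FP32_FUNCS naming; the entries below are examples only, not the actual contents of the PR's lists:

# Ops that are expected to run in fp16 (example entries only)
FP16_FUNCS = [
    'Convolution',
    'FullyConnected',
]

# Ops that can safely run in either fp16 or fp32 (example entries only)
FP16_FP32_FUNCS = [
    'abs',
    'BatchNorm',
]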
aux = sym.list_auxiliary_states()
inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
                  if x.name not in aux else x, inputs))
atomic_sym = sym._gen_atomic_symbol()
Why is this needed?
Basically this allows us to cast inputs to the op that were not specified by the user (e.g. when you make a convolution via the symbolic API, you don't need to pass weights and biases to it - MXNet will generate them). So what we do here is create a symbol using the original function (which will create all the other children), then take those children and cast them. However, neither MXNet nor NNVM allows manipulating a symbol's children after it has been created. So we create a new, atomic symbol (atomic means it does not have any children set yet, but has the same params as the symbol used to create it) and populate that symbol with the cast inputs.
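To make that flow concrete, here is a rough sketch of the wrapping described above; _cast_symbol_NDArray is the helper shown in the earlier snippets, and the exact calls (get_children, _gen_atomic_symbol, _compose) should be treated as approximations of the real implementation:

def _wrap_symbol_fn(f, target_dtype):
    """Sketch: wrap a symbol-creating function so its inputs get cast."""
    def wrapper(*args, **kwargs):
        # Let the original function build the symbol, so MXNet creates
        # the hidden children (e.g. convolution weight and bias).
        sym = f(*args, **kwargs)
        aux = sym.list_auxiliary_states()
        # Cast every non-auxiliary child to the target dtype.
        inputs = [x if x.name in aux else _cast_symbol_NDArray(x, target_dtype)
                  for x in sym.get_children()]
        # Children cannot be changed in place, so build an atomic symbol
        # (same params, no children yet) and compose it with the cast inputs.
        atomic_sym = sym._gen_atomic_symbol()
        atomic_sym._compose(*inputs, name=sym.name)
        return atomic_sym
    return wrapper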
When pretrained models are loaded under amp.init, will it silently fail?
Pretrained as in a symbol? That will not do anything (it needs your PR).
Pretrained as in load_parameters? That will work as expected (I was using it for all my experiments with GluonCV, where e.g. SSD uses a pretrained RN50 backbone).
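As a concrete example of the load_parameters path mentioned above (the GluonCV model name is only an illustration):

import mxnet as mx
from mxnet.contrib import amp
import gluoncv

# Initialize AMP before constructing the network
amp.init()

# Pretrained Gluon parameters load as usual; casts are inserted around
# the wrapped operators when the network runs.
net = gluoncv.model_zoo.get_model('ssd_512_resnet50_v1_voc', pretrained=True)
net.hybridize()
cls_ids, scores, boxes = net(mx.nd.zeros((1, 3, 512, 512)))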
Overall this change LGTM. This will be really useful to our users. Thanks a lot for your effort @ptrendx!
Can we have unit tests for loss scaler, multi_cast and isfinite ops to make sure they're not broken in the future?
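For illustration, a sketch of what such tests could look like; the op names and namespaces (e.g. amp_cast, amp_multicast) follow this PR but should be double-checked against the final code:

import mxnet as mx
import numpy as np

def test_amp_cast():
    # amp_cast should cast float32 -> float16 ...
    x = mx.nd.ones((2, 2), dtype='float32')
    assert mx.nd.amp_cast(x, dtype='float16').dtype == np.float16
    # ... and leave non-float types untouched
    i = mx.nd.ones((2, 2), dtype='int32')
    assert mx.nd.amp_cast(i, dtype='float16').dtype == np.int32

def test_amp_multicast():
    # amp_multicast casts all inputs to the widest floating point type
    a = mx.nd.ones((2, 2), dtype='float16')
    b = mx.nd.ones((2, 2), dtype='float32')
    out = mx.nd.amp_multicast(a, b, num_outputs=2)
    assert all(o.dtype == np.float32 for o in out)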
* Beginning of AMP
* Optimize noop cast
* More operations added
* Backward cast
* Adding AMPCast and AMPMultiCast
* Fix some of lint
* Changed symbol wrapper to handle hidden inputs; added PoC of dynamic loss scaling
* Moved back to dmlc/tvm repo
* fix counter reset to increase loss scale every 2k iterations
* Fix indentation
* Add contrib from symbol and ndarray to symbol list
* Adding where to widest type cast
* Do not cast in imperative mode on CPU context
* Update dmlc-core to fix unittests
* Fix wrapper metadata, fix self handling
* Blacklist sync batchnorm (since its implementation is FP32 only)
* Fix lint
* Enable losses to be tuple
* Get rid of AMP handle
* Add scaling to Output functions
* Fix pylint
* Update dmlc-core
* Changing prints in AMP to logging.info
* NNVM -> MXNet for FInferShape
* Bring the inplaceidentity fix to copied pass from NNVM
* Added tutorial for AMP
* Making Windows compiler happy
* Fixes to tutorial
* More fixes
* Fix lint
* Fix
* Add amp/index.md to whitelist for tutorial tests
* Whitelisting cuDNN RNN
* Manual unscale
* _internal functions wrapping
* Make SymbolFunctor from Symbol
* Fix the type infer function of AMP multicast
* Added ability to override casting lists
* Making clang-tidy and pylint happy
* More cleaning
* Making clang-tidy really happy
* remove amp_cast and amp_multicast before saving the model
* Changes from review
* Add RemoveAmpCast in a separate c_api function, add the option in symbol.save
* add remove_amp_cast option (True by default) to every way of saving symbol
* Fix
* First stab at adding the gray list
* More ops added
* Adding the rest of the functions
* Improvements to AMP test
* Changing of names and changing wrapping
* Moving to contrib
* Modifying tutorial for contrib AMP
* Removing non existent functions
* Fix import in test
* Fix lint
* Added new functions
* Added assert
* Fix the unknown ndim in PlanMemory pass
* Moving back to FP16_FUNCS and FP16_FP32_FUNCS
* Removing unnecessary ops
* Adding ops that exist only in some build configurations and removing tests checking that AMP lists contain only existing ops
* Removing warning when not every function was found during AMP init because of functions being available only in specific configurations
* Add tests and doc
* Fix the CPU version of all_finite
* Adding test cases for all_finite operator
* Add new operators
* Fix
Description
This is a work-in-progress PR for AMP (automatic mixed precision) support in MXNet, similar to the PyTorch version found in https://github.com/NVIDIA/apex.
This PR relies on multiple other PRs and bug fixes, listed in the Comments section.
The dynamic loss scaling part was done by @Caenorst (commits were squashed for easier rebasing).
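A minimal usage sketch of the dynamic loss scaling flow; the API names (amp.init, amp.init_trainer, amp.scale_loss) follow the tutorial added in this PR, but treat the details as illustrative:

import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

amp.init()  # patch symbol/ndarray functions for mixed precision

net = gluon.nn.Dense(16)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
amp.init_trainer(trainer)  # enable dynamic loss scaling for this trainer

data = mx.nd.ones((8, 32))
with autograd.record():
    loss = net(data).sum()
    # Scale the loss before backward so fp16 gradients do not underflow;
    # the scale is adjusted automatically when overflows are detected.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(batch_size=8)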
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
* amp_cast and amp_multicast operators, which handle casting between FP16/FP32 when necessary and do not change other types. They are optimized to not do anything if the input is already in the proper type.
Comments
FYI @eric-haibin-lin @szha