Improve AMP, bf16 support. Support oneDNN ops in AMP #20753
Conversation
Hey @PawelGlomski-Intel , Thanks for submitting the PR
CI supported jobs: [clang, windows-cpu, miscellaneous, centos-gpu, windows-gpu, unix-cpu, sanity, unix-gpu, centos-cpu, website, edge]
0fba44c to 29ace88
Hi @ptrendx. Could you please explain what the use case is of
Hi, the purpose of
Not sure what you mean by casting inputs to
Also, adding @mk-61 to the discussion.
Thanks a lot, I thought that was the case with
Regarding the inference optimization - it indeed also applies to the
Hmmm, yeah, not 100% sure about that. I assume this is the effect of this
@szha What is your take on this?
Jenkins CI successfully triggered : [unix-cpu, clang, centos-gpu]
Minor comments - LGTM
@mxnet-bot run ci [unix-gpu, windows-gpu]
Jenkins CI successfully triggered : [unix-gpu, windows-gpu]
@mxnet-bot run ci [centos-gpu, clang]
Jenkins CI successfully triggered : [clang, centos-gpu]
Description
General AMP improvements, extended bf16 type support, and support for oneDNN operators in AMP (including `amp_cast` fusion).
Backward-incompatible changes
The current model conversion API (`convert_symbol`, `convert_model`, and `convert_hybrid_block`) works without type information and assumes that everything operates on floats. To ensure the correctness of conversion for any model and operator, complete type information must be present. This PR (in addition to numerous fixes) introduces backward-incompatible changes in order to eliminate these assumptions.

Changes in the user API
The user will have to provide information about the data types of the model inputs.
In MXNet 2.0, the main function for model conversion should be `convert_hybrid_block`. Since the type information will usually be derived from some data example, the user should provide the data example itself. The previous implementation required the model to be hybridized and run at least once before conversion (so that the graph is created). This requirement can now be removed since, with a data example, the graph can be generated inside `convert_hybrid_block` (if it is not already present).

Additionally, `cast_optional_params` was renamed to `cast_params_offline`, which I believe is more descriptive.

Before:

After:
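The snippet below is a minimal sketch of the change, not code from this PR: the network, data, and exact call signatures are assumptions based on the description above, so the MXNet calls are shown as comments, and the small porting helper is hypothetical.

```python
# Hypothetical sketch of the API change described above. The amp calls are
# shown as comments because they require an MXNet build; the argument names
# (data_example, cast_params_offline) follow this PR's description.
#
#   net = mx.gluon.nn.Dense(10)
#   # Before: hybridize and run once so the graph exists before converting:
#   net.hybridize()
#   net(data)
#   net = amp.convert_hybrid_block(net, target_dtype='bfloat16',
#                                  cast_optional_params=True)
#   # After: pass a data example; the graph is generated inside the call:
#   net = amp.convert_hybrid_block(net, data_example=data,
#                                  target_dtype='bfloat16',
#                                  cast_params_offline=True)

def port_kwargs(kwargs):
    """Rename the old flag to the new one when porting old scripts.

    A hypothetical helper for illustration, not part of MXNet."""
    kwargs = dict(kwargs)
    if 'cast_optional_params' in kwargs:
        kwargs['cast_params_offline'] = kwargs.pop('cast_optional_params')
    return kwargs

print(port_kwargs({'target_dtype': 'bfloat16', 'cast_optional_params': True}))
# {'target_dtype': 'bfloat16', 'cast_params_offline': True}
```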
`convert_model` and `convert_symbol` must also be changed accordingly.

Before:

After:

Here `input_dtypes` is expected to be a dictionary, mapping names of model inputs to their types.

Before:

After:
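To make the dictionary arguments concrete, here is a minimal sketch; the tensor and parameter names (`data`, `fc_weight`, `fc_bias`) are invented for illustration, and the conversion calls are shown as comments since they need an MXNet build.

```python
# Hypothetical sketch of the new type-information arguments. The dictionary
# shapes follow this PR's description; the calls that would consume them are
# shown as comments:
#
#   sym, args, auxs = amp.convert_model(sym, args, auxs,
#                                       input_dtypes=input_dtypes,
#                                       target_dtype='bfloat16')
#   sym = amp.convert_symbol(sym, input_dtypes=input_dtypes,
#                            param_dtypes=param_dtypes,
#                            target_dtype='bfloat16')

# input_dtypes maps names of model inputs to their types:
input_dtypes = {'data': 'float32'}
# param_dtypes maps names of model parameters to their types:
param_dtypes = {'fc_weight': 'float32', 'fc_bias': 'float32'}

def check_dtypes(required_names, dtypes):
    """Conversion needs a type for every name; report anything missing.

    A hypothetical sanity check, not an MXNet API."""
    missing = [n for n in required_names if n not in dtypes]
    if missing:
        raise ValueError(f'missing dtypes for: {missing}')

check_dtypes(['data'], input_dtypes)                  # passes
check_dtypes(['fc_weight', 'fc_bias'], param_dtypes)  # passes
```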
Similarly, `convert_symbol` will now require `input_dtypes` as well as `param_dtypes`, which will also be a dictionary, mapping names of model parameters to their types.

Changes in behavior
Since we now have full information about the model types, the operator `amp_multicast`, which previously introduced a lot of uncertainty about the types of the layers it preceded and had to be handled with (excessive) `amp_cast` nodes, is no longer needed. It is replaced by multiple `amp_cast` nodes (inserted only for tensors that actually need casting), which can then usually be fused with oneDNN-optimized nodes by the `ONEDNN_AMP` pass (if the model was first optimized for the `ONEDNN` backend).

Because `amp_multicast` is no longer used, the problem with offline casting of parameters discussed below is no longer present.
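A toy illustration of the rule this change applies (this is not MXNet's actual pass implementation, just a sketch): with full type information, a cast node is inserted per tensor only where the source dtype differs from the required one.

```python
# Toy model of the behavioral change: previously one amp_multicast node
# covered all inputs of a layer; with full type information, an amp_cast is
# inserted only for the tensors whose dtype differs from the required one.
# (Illustrative only -- not MXNet's actual ONEDNN_AMP pass.)
def casts_needed(input_dtypes, required_dtype):
    """Map each input name to whether it needs an amp_cast node."""
    return {name: dt != required_dtype
            for name, dt in input_dtypes.items()}

print(casts_needed({'x': 'float32', 'w': 'bfloat16'}, 'bfloat16'))
# {'x': True, 'w': False}
```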