
Improve AMP, bf16 support. Support oneDNN ops in AMP #20753

Merged: 40 commits merged into apache:master on Mar 31, 2022

Conversation

@PawelGlomski-Intel (Contributor) commented Nov 24, 2021

Description

General AMP improvements, extended bf16 type support, and oneDNN operator support in AMP (including amp_cast fusion)

Backward-incompatible changes

The current model conversion API (convert_symbol, convert_model, and convert_hybrid_block) works without type information and assumes that everything operates on floats. To ensure that any model and operator can be converted correctly, complete type information must be available. This PR (in addition to numerous fixes) introduces backward-incompatible changes in order to eliminate these assumptions.

Changes in the user API

The user will have to provide information about the data types of the model inputs.

In MXNet 2.0, the main function for model conversion should be convert_hybrid_block. Since the type information will usually be derived from some data example, the user should provide the data example itself. The previous implementation required the model to be hybridized and run at least once before conversion (so that the graph is created). This requirement can now be removed, since with a data example the graph can be generated inside convert_hybrid_block (if it is not already present).

Additionally, cast_optional_params was renamed to cast_params_offline, which I believe is more descriptive.

Before:

def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
                         fp32_ops=None, conditional_fp32_ops=None,
                         excluded_sym_names=None, device=gpu(0),
                         cast_optional_params=False):

After:

def convert_hybrid_block(block, data_example, target_dtype="float16", target_dtype_ops=None,
                         fp32_ops=None, conditional_fp32_ops=None,
                         excluded_sym_names=[], device=gpu(0),
                         cast_params_offline=False):
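
For illustration, here is a minimal usage sketch of the new signature. The import path (from mxnet import amp), the toy Dense block, and the 'bfloat16' dtype string are assumptions made for this example, not part of this PR:

import mxnet as mx
from mxnet import amp, gluon

# A toy model standing in for any HybridBlock.
net = gluon.nn.Dense(16)
net.initialize()

# The data example drives shape/type inference; hybridizing and running
# the block before conversion is no longer required.
data_example = mx.np.ones((8, 32), dtype='float32')

# Convert for bf16 inference on CPU (oneDNN), casting parameters offline.
lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='bfloat16',
                                  device=mx.cpu(), cast_params_offline=True)
out = lp_net(data_example)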

convert_model and convert_symbol must also be changed accordingly:

Before:

def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None,
                  cast_optional_params=False):

After:

def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype="float16",
                  target_dtype_ops=None, fp32_ops=None, conditional_fp32_ops=None,
                  excluded_sym_names=[], cast_params_offline=False):

Here, input_dtypes is expected to be a dictionary mapping the names of the model inputs to their types.
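
For example, a hedged sketch of how input_dtypes might be built and passed (the tiny FullyConnected symbol, its parameter values, and the use of dtype strings as dictionary values are assumptions for illustration):

import mxnet as mx
from mxnet import amp

# Placeholder symbolic model: a single fully connected layer.
data = mx.sym.Variable('data')
weight = mx.sym.Variable('fc_weight')
bias = mx.sym.Variable('fc_bias')
sym = mx.sym.FullyConnected(data, weight=weight, bias=bias, num_hidden=8, name='fc')

arg_params = {'fc_weight': mx.nd.ones((8, 4)), 'fc_bias': mx.nd.zeros((8,))}
aux_params = {}

# Every model input (i.e. non-parameter) is mapped to its dtype.
input_dtypes = {'data': 'float32'}

lp_sym, lp_args, lp_auxs = amp.convert_model(
    sym, arg_params, aux_params, input_dtypes,
    target_dtype='float16', cast_params_offline=True)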

Before:

def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
                   fp32_ops=None, conditional_fp32_ops=None,
                   excluded_sym_names=None, data_names=None,
                   cast_optional_params=False):

After:

def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype="float16", target_dtype_ops=None,
                   fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=[],
                   cast_params_offline=False):

Similarly, convert_symbol now requires input_dtypes as well as param_dtypes, which is also a dictionary, mapping the names of the model parameters to their types.
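
A corresponding hedged sketch for convert_symbol (again, the toy symbol and the dtype strings are illustrative assumptions):

import mxnet as mx
from mxnet import amp

data = mx.sym.Variable('data')
weight = mx.sym.Variable('fc_weight')
bias = mx.sym.Variable('fc_bias')
sym = mx.sym.FullyConnected(data, weight=weight, bias=bias, num_hidden=8, name='fc')

# Inputs and parameters are now described separately, each name mapped to a dtype.
input_dtypes = {'data': 'float32'}
param_dtypes = {'fc_weight': 'float32', 'fc_bias': 'float32'}

lp_sym = amp.convert_symbol(sym, input_dtypes, param_dtypes, target_dtype='float16')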

Changes in behavior

Since we now have full type information for the model, the amp_multicast operator is no longer needed. Previously, it introduced a lot of uncertainty about the types of the layers it preceded and had to be handled with (excessive) amp_cast nodes. It is now replaced by multiple amp_cast nodes (only for the tensors that actually need casting), which can then usually be fused with oneDNN-optimized nodes by the ONEDNN_AMP pass (if the model was first optimized for the ONEDNN backend).
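
As a conceptual sketch of the difference (the variable names and dtypes are illustrative; this shows how the resulting graph roughly changes, not the internal pass API):

import mxnet as mx

a = mx.sym.Variable('a')  # known to be float32
b = mx.sym.Variable('b')  # known to be float16

# Before: without type information, both inputs went through amp_multicast,
# which resolves the widest type only during type inference.
mc = mx.sym.amp_multicast(a, b, num_outputs=2)
old_style = mc[0] + mc[1]

# After: with full type information, only the tensor that actually needs
# casting gets an amp_cast node (here b is widened to float32).
new_style = a + mx.sym.amp_cast(b, dtype='float32')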

Because amp_multicast will not be used, the problem with offline casting of parameters discussed below is no longer present.

@mxnet-bot

Hey @PawelGlomski-Intel, thanks for submitting the PR.
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [clang, windows-cpu, miscellaneous, centos-gpu, windows-gpu, unix-cpu, sanity, unix-gpu, centos-cpu, website, edge]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@mseth10 added the pr-work-in-progress label on Nov 24, 2021.
@PawelGlomski-Intel changed the title from "[WIP] Improve AMP, bf16 support, fuse amp_cast with oneDNN ops" to "[WIP] Improve AMP, bf16 support, support oneDNN ops in AMP" on Nov 24, 2021.
@PawelGlomski-Intel changed the title from "[WIP] Improve AMP, bf16 support, support oneDNN ops in AMP" to "[WIP] Improve AMP, bf16 support. Support oneDNN ops in AMP" on Nov 24, 2021.
@PawelGlomski-Intel marked this pull request as draft on November 24, 2021 12:11.
@PawelGlomski-Intel (Contributor, Author) commented Nov 25, 2021

Hi @ptrendx. Could you please explain what the use case of amp_multicast is and why we can't just use multiple amp_cast nodes? Also, why would we cast the parameters that are inputs of amp_multicast to low precision offline? Doesn't that contradict the purpose of using amp_multicast in these cases?

@ptrendx (Member) commented Nov 25, 2021

Hi, the purpose of amp_multicast is to cast the inputs of ops that take multiple of them (think add, for example) to the same, widest precision. We can't use multiple amp_cast nodes for this, since this cast happens before type inference and so we do not know what that widest type is. An example: let's take the a + b operation. If both a and b are float16, you do not want to insert any casts. If one of them (let's say a) is float32 and the other (b) is float16, then you want to cast b to float32.

Not sure what you mean by casting inputs to amp_multicast offline. There is an optimization for inference cases to cast parameters to float16, but those do not go into amp_multicast but to the regular amp_cast I believe.

@ptrendx (Member) commented Nov 25, 2021

Also, adding @mk-61 to the discussion.

@PawelGlomski-Intel (Contributor, Author) commented Nov 26, 2021

(Quoting @ptrendx's reply above.)

Thanks a lot, I thought that was the case with amp_multicast but wasn't 100% sure.

Regarding the inference optimization: it indeed also applies to amp_multicast, and it is even tested. I removed this as it seemed illogical to me, so now these tests fail. Here is a comment about adding this. Was that an incorrect approach, and is my current version correct?
Here is one of the tests (BTW, the dtype of a variable doesn't matter at all here; it will always cast these parameters to fp16):
https://github.com/apache/incubator-mxnet/blob/29ace886946941527047dc5deebe5b4b85b5e4cb/tests/python/gpu/test_amp.py#L133-L139

@ptrendx (Member) commented Nov 27, 2021

Hmmm, yeah, not 100% sure about that. I assume this is the effect of the cast_optional_params option? I think I would remove it, since it kind of defeats the purpose of AMP (if you want full fp16, you should just cast to it).

@PawelGlomski-Intel (Contributor, Author) commented:

@szha What is your take on this?
When a model parameter is an input of an amp_multicast node, it will be cast to lp16 (with cast_optional_params set to True). An amp_multicast node is added for ops from this list, so that their inputs share one (the most accurate) dtype. By definition, such ops should only run in lp16 when all of their inputs are already in low precision, while currently, input parameters are always cast to lp16, even when they are all fp32. I don't think this is intuitive.

@mxnet-bot

Jenkins CI successfully triggered : [unix-cpu, clang, centos-gpu]

@mseth10 added the pr-awaiting-testing label and removed the pr-work-in-progress label on Mar 28, 2022.
@bgawrych (Contributor) left a comment:

Minor comments - LGTM

Review threads (outdated, resolved) on:
  • src/nnvm/low_precision_pass.cc
  • src/operator/subgraph/dnnl/dnnl_transformer.cc
  • tests/python/dnnl/test_amp.py
  • tests/python/gpu/test_amp.py (two threads)
@mseth10 added the pr-work-in-progress label and removed the pr-awaiting-testing label on Mar 28, 2022.
@PawelGlomski-Intel (Contributor, Author) commented:

@mxnet-bot run ci [unix-gpu, windows-gpu]

@mxnet-bot

Jenkins CI successfully triggered : [unix-gpu, windows-gpu]

@mseth10 added the pr-awaiting-testing and pr-awaiting-review labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Mar 29, 2022.
@mseth10 added the pr-awaiting-testing, pr-awaiting-review, and pr-work-in-progress labels and removed the pr-awaiting-review and pr-awaiting-testing labels on Mar 30, 2022.
@PawelGlomski-Intel (Contributor, Author) commented:

@mxnet-bot run ci [centos-gpu, clang]

@mxnet-bot

Jenkins CI successfully triggered : [clang, centos-gpu]

@mseth10 added the pr-awaiting-testing and pr-awaiting-review labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Mar 30, 2022.
@bgawrych merged commit e8ff13c into apache:master on Mar 31, 2022.