Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oneDNN] Initial bf16 amp integration #31093

Merged
33 commits merged
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
776eca3
Initial bf16 amp integration
arlesniak Feb 21, 2021
36a1257
Updates for CI
arlesniak Feb 24, 2021
9fc6899
More updates for CI
arlesniak Feb 25, 2021
1261873
More updates for CI
arlesniak Feb 25, 2021
d532ca0
Added test for bf16_utils
arlesniak Feb 25, 2021
deb275f
Added test for bf16_utils
arlesniak Feb 25, 2021
a3fdf51
Changes for CI
arlesniak Feb 26, 2021
204760a
Changes for CI, more tests
arlesniak Feb 28, 2021
4e5555f
Changes for CI
arlesniak Mar 1, 2021
3c9e45f
Changes for CI
arlesniak Mar 1, 2021
a6a9518
Changes for CI
arlesniak Mar 1, 2021
8c1b7b7
Improvements
arlesniak Mar 3, 2021
654904b
Refactor
arlesniak Mar 10, 2021
213737a
Refactor
arlesniak Mar 10, 2021
be8e759
Changes for CI
arlesniak Mar 10, 2021
fc665a5
Changes for CI
arlesniak Mar 10, 2021
7098724
More tests
arlesniak Mar 10, 2021
ed6cd06
More tests, introduced bf16 scale op
arlesniak Mar 10, 2021
8d65d44
Changes for CI
arlesniak Mar 11, 2021
f4f958b
Changes for CI
arlesniak Mar 11, 2021
6117982
Changes for CI
arlesniak Mar 12, 2021
72405cf
Changes for CI
arlesniak Mar 12, 2021
c213d08
Changes for CI
arlesniak Mar 17, 2021
b6c4ad2
Changes for CI
arlesniak Mar 17, 2021
0bda415
Changes for CI
arlesniak Mar 17, 2021
68ea1f9
Changes for CI
arlesniak Mar 17, 2021
09a2f47
Changes for CI
arlesniak Mar 18, 2021
274fa0b
Changes for CI
arlesniak Mar 18, 2021
11dc278
Changes for CI
arlesniak Mar 18, 2021
bd2dea8
Changes to trigger blocked CIs
arlesniak Mar 18, 2021
f07ca15
Changes for CI
arlesniak Mar 18, 2021
d8f810c
Less lines in amp_lists.py
arlesniak Mar 19, 2021
6fb0cb8
Changes after review
arlesniak Mar 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
2 changes: 2 additions & 0 deletions paddle/fluid/operators/scale_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from .fp16_lists import *
from . import fp16_utils
from .fp16_utils import *
from . import bf16
from .bf16 import *

__all__ = decorator.__all__
__all__ += fp16_lists.__all__
__all__ += fp16_utils.__all__
__all__ += bf16.__all__
24 changes: 24 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *

__all__ = []
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
97 changes: 97 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
gray_list as gray_list_fp16, unsupported_fp16_list
arlesniak marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ["AutoMixedPrecisionListsBF16"]


class AutoMixedPrecisionListsBF16(object):
    """
    Holds the fp32/bf16 op-type lists used by the auto-mixed-precision
    pass to decide each op's execution mode (fp32 or bf16).  The
    pre-defined ``bf16_list``/``fp32_list``/``gray_list`` are copied per
    instance and then adjusted with the user-supplied custom lists.

    Args:
        custom_bf16_list (set): Users' custom bf16 list.
        custom_fp32_list (set): Users' custom fp32 list.
        custom_fp32_varnames (set): Users' custom fp32 variables' names.

    Raises:
        ValueError: if the custom bf16 and fp32 lists share an op type.

    Examples:
        .. code-block:: python
        import paddle
        paddle.enable_static()
        with paddle.static.amp.bf16_guard():
            paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
    """

    def __init__(self,
                 custom_bf16_list=None,
                 custom_fp32_list=None,
                 custom_fp32_varnames=None):
        self._custom_bf16_list = custom_bf16_list
        self._custom_fp32_list = custom_fp32_list
        # Per-instance shallow copies so custom updates never mutate the
        # module-level defaults shared by other instances.
        self.bf16_list = copy.copy(bf16_list)
        self.fp32_list = copy.copy(fp32_list)
        self.gray_list = copy.copy(gray_list)
        self.unsupported_list = copy.copy(unsupported_list)
        self.fp32_varnames = copy.copy(custom_fp32_varnames)
        self._update_list()

    @staticmethod
    def _move_op(op_name, first, second, target):
        # Remove op_name from the first of the two sets that contains it
        # (mirrors the original if/elif: `second` keeps the op when both
        # sets contain it), then record it in `target`.
        if op_name in first:
            first.remove(op_name)
        elif op_name in second:
            second.remove(op_name)
        target.add(op_name)

    def _update_list(self):
        """
        Fold the users' custom bf16/fp32 lists into the per-instance
        pre-defined lists.
        """
        if self._custom_bf16_list and self._custom_fp32_list:
            # The same op type cannot be forced into both modes.
            if any(op_name in self._custom_fp32_list
                   for op_name in self._custom_bf16_list):
                raise ValueError("Custom bf16 list overlap "
                                 "custom fp32 list")
        if self._custom_bf16_list:
            for op_name in self._custom_bf16_list:
                self._move_op(op_name, self.fp32_list, self.gray_list,
                              self.bf16_list)
        if self._custom_fp32_list:
            for op_name in self._custom_fp32_list:
                self._move_op(op_name, self.bf16_list, self.gray_list,
                              self.fp32_list)
                # Ops forced to fp32 must also be treated as unsupported
                # for bf16 casting.
                self.unsupported_list.add(op_name)


# Pre-defined op-type lists for the bf16 auto-mixed-precision pass.
# (Review-page chatter that had been scraped into the middle of the
# gray_list literal is removed; the code below is the actual source.)

# Ops always executed in bf16.
bf16_list = {'elementwise_add', }

# Ops whose mode depends on the prev_op type.
gray_list = {
    'reshape2',
    'lookup_table',
}

# Start from the fp16 lists: everything fp16 can't handle is also
# unsupported for bf16, and every fp16-classified op defaults to fp32
# here.  A single shallow .copy() suffices — the original's
# .copy().copy() duplicated the set twice for no benefit.
unsupported_list = unsupported_fp16_list.copy()
fp32_list = black_list_fp16.copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16

# The bf16/gray choices above take precedence over the fp16-derived
# defaults, so carve them out.
fp32_list -= bf16_list
fp32_list -= gray_list
unsupported_list -= bf16_list
unsupported_list -= gray_list
Loading