[oneDNN] Initial bf16 amp integration (#31093)
arlesniak authored Mar 22, 2021
1 parent a501a7b commit 7ccf6b6
Showing 18 changed files with 777 additions and 38 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/cast_op.cc
@@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
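
With this kernel registered, cast can convert CPU tensors between float32 and bfloat16. A minimal usage sketch (not part of the commit; it assumes a build of this branch in which 'bfloat16' is accepted as a Python-side dtype string):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    # Dispatches to the newly registered ops::CastOpKernel<CPU, bfloat16>.
    y = paddle.cast(x, 'bfloat16')
    # Round-trip back; bf16 keeps fp32's 8-bit exponent but only 7 mantissa bits.
    z = paddle.cast(y, 'float32')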
2 changes: 2 additions & 0 deletions paddle/fluid/operators/scale_op.cc
@@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
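Likewise, the new ScaleKernel instantiation lets scale (out = scale * x + bias) run directly on bfloat16 CPU tensors, so an AMP-rewritten program does not need fp32 casts around it. A sketch under the same assumptions as the cast example above:

import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[4], dtype='float32')
    xb = paddle.cast(x, 'bfloat16')
    # Dispatches to ops::ScaleKernel<CPUDeviceContext, bfloat16>.
    yb = paddle.scale(xb, scale=2.0, bias=0.5)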
3 changes: 3 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/__init__.py
@@ -20,7 +20,10 @@
from .fp16_lists import *
from . import fp16_utils
from .fp16_utils import *
from . import bf16
from .bf16 import *

__all__ = decorator.__all__
__all__ += fp16_lists.__all__
__all__ += fp16_utils.__all__
__all__ += bf16.__all__
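
These re-exports surface the new bf16 subpackage next to the existing fp16 helpers under the contrib mixed_precision namespace. A usage sketch (not from the commit):

from paddle.fluid.contrib import mixed_precision as mp

# AutoMixedPrecisionListsBF16 reaches mp via bf16.__all__ -> mixed_precision.__all__.
lists = mp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lookup_table'})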
24 changes: 24 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
@@ -0,0 +1,24 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *

__all__ = []
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
97 changes: 97 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
@@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
gray_list as gray_list_fp16, unsupported_fp16_list

__all__ = ["AutoMixedPrecisionListsBF16"]


class AutoMixedPrecisionListsBF16(object):
"""
AutoMixedPrecisionListsBF16 is a class that holds the fp32/bf16 op-type lists. The lists are used
by an algorithm that determines each op's execution mode (fp32 or bf16). It can update the
pre-defined fp32 and bf16 lists according to users' custom fp32/bf16 lists.
Args:
custom_bf16_list (set): Users' custom bf16 list.
custom_fp32_list (set): Users' custom fp32 list.
custom_fp32_varnames (set): Users' custom fp32 variables' names.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
with paddle.static.amp.bf16_guard():
paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
"""

def __init__(self,
custom_bf16_list=None,
custom_fp32_list=None,
custom_fp32_varnames=None):
self._custom_bf16_list = custom_bf16_list
self._custom_fp32_list = custom_fp32_list
self.bf16_list = copy.copy(bf16_list)
self.fp32_list = copy.copy(fp32_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_list)
self.fp32_varnames = copy.copy(custom_fp32_varnames)
self._update_list()

def _update_list(self):
"""
Update fp32 and bf16 list according to users' custom list.
"""
if self._custom_bf16_list and self._custom_fp32_list:
for op_name in self._custom_bf16_list:
if op_name in self._custom_fp32_list:
raise ValueError("Custom bf16 list overlaps "
"custom fp32 list")
if self._custom_bf16_list:
for op_name in self._custom_bf16_list:
if op_name in self.fp32_list:
self.fp32_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.bf16_list.add(op_name)
if self._custom_fp32_list:
for op_name in self._custom_fp32_list:
if op_name in self.bf16_list:
self.bf16_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.fp32_list.add(op_name)
self.unsupported_list.add(op_name)


# always bf16
bf16_list = {'elementwise_add', }

# depends on the prev_op type
gray_list = {
'reshape2',
'lookup_table',
}

unsupported_list = unsupported_fp16_list.copy()
fp32_list = black_list_fp16.copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16

fp32_list -= bf16_list
fp32_list -= gray_list
unsupported_list -= bf16_list
unsupported_list -= gray_list
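
A standalone sketch (not part of the commit) showing how _update_list resolves the module-level lists above: forcing an op to fp32 removes it from the bf16/gray lists and also marks it unsupported, and overlapping custom lists are rejected:

from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionListsBF16

# 'elementwise_add' starts in bf16_list; a custom fp32 entry moves it out.
lists = AutoMixedPrecisionListsBF16(custom_fp32_list={'elementwise_add'})
assert 'elementwise_add' not in lists.bf16_list
assert 'elementwise_add' in lists.fp32_list
assert 'elementwise_add' in lists.unsupported_list

# Naming the same op in both custom lists raises ValueError.
try:
    AutoMixedPrecisionListsBF16(custom_bf16_list={'reshape2'},
                                custom_fp32_list={'reshape2'})
except ValueError as err:
    print(err)  # Custom bf16 list overlaps custom fp32 list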