From 7ccf6b60306c700f59f5eb94d21abec323cd06eb Mon Sep 17 00:00:00 2001
From: arlesniak
Date: Mon, 22 Mar 2021 07:43:33 +0100
Subject: [PATCH] [oneDNN] Initial bf16 amp integration (#31093)

---
 paddle/fluid/operators/cast_op.cc             |   1 +
 paddle/fluid/operators/scale_op.cc            |   2 +
 .../fluid/contrib/mixed_precision/__init__.py |   3 +
 .../contrib/mixed_precision/bf16/__init__.py  |  24 ++
 .../contrib/mixed_precision/bf16/amp_lists.py |  97 ++++++
 .../contrib/mixed_precision/bf16/amp_utils.py | 296 ++++++++++++++++++
 .../contrib/mixed_precision/fp16_lists.py     |   2 +-
 .../fluid/contrib/tests/test_bf16_utils.py    | 144 +++++++++
 .../contrib/tests/test_model_cast_to_bf16.py  | 138 ++++++++
 python/paddle/fluid/data_feeder.py            |  23 +-
 python/paddle/fluid/layers/nn.py              |  16 +-
 .../fluid/tests/book/test_fit_a_line.py       |  17 +-
 .../fluid/tests/book/test_word2vec_book.py    |  29 +-
 .../paddle/fluid/tests/unittests/op_test.py   |  17 +-
 python/paddle/static/amp/__init__.py          |   3 +
 python/setup.py.in                            |   1 +
 tools/parallel_UT_rule.py                     |   1 +
 tools/static_mode_white_list.py               |   1 +
 18 files changed, 777 insertions(+), 38 deletions(-)
 create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
 create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
 create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
 create mode 100644 python/paddle/fluid/contrib/tests/test_bf16_utils.py
 create mode 100644 python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py

diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc
index c5cfa7a3bafce..40f4b969ec060 100644
--- a/paddle/fluid/operators/cast_op.cc
+++ b/paddle/fluid/operators/cast_op.cc
@@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel,
                        ops::CastOpKernel,
                        ops::CastOpKernel,
                        ops::CastOpKernel,
+                       ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
                        ops::CastOpKernel,
                        ops::CastOpKernel);
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index 281689d3bdaff..a9b1f299dab82 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel,
     ops::ScaleKernel,
+    ops::ScaleKernel<paddle::platform::CPUDeviceContext,
+                     paddle::platform::bfloat16>,
     ops::ScaleKernel,
     ops::ScaleKernel,
     ops::ScaleKernel,
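
For orientation before the Python-side changes: bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa bits of an IEEE float32, i.e. it is the upper half of the 32-bit word, and Paddle carries such values in uint16 tensors. The standalone sketch below (plain Python, not part of the patch) shows the representation that the new cast/scale kernels and the uint16 plumbing in the rest of this change rely on:

    import struct

    import numpy as np

    def f32_to_bf16_bits(x):
        # reinterpret the float32 as a 32-bit word and keep its upper half
        return np.uint16(
            struct.unpack('<I', struct.pack('<f', np.float32(x)))[0] >> 16)

    def bf16_bits_to_f32(b):
        # shift the 16 stored bits back into the high half of a float32 word
        return struct.unpack('<f', struct.pack('<I', np.uint32(b) << 16))[0]

    print(hex(f32_to_bf16_bits(1.0)))               # 0x3f80
    print(bf16_bits_to_f32(f32_to_bf16_bits(3.2)))  # 3.1875, ~0.4% truncation error
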
diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py
index a580ae5574c35..571b755b50d2a 100644
--- a/python/paddle/fluid/contrib/mixed_precision/__init__.py
+++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py
@@ -20,7 +20,10 @@
 from .fp16_lists import *
 from . import fp16_utils
 from .fp16_utils import *
+from . import bf16
+from .bf16 import *
 
 __all__ = decorator.__all__
 __all__ += fp16_lists.__all__
 __all__ += fp16_utils.__all__
+__all__ += bf16.__all__
diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py b/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
new file mode 100644
index 0000000000000..8c05bc4899cf7
--- /dev/null
+++ b/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+from . import amp_lists
+from .amp_lists import *
+from . import amp_utils
+from .amp_utils import *
+
+__all__ = []
+__all__ += amp_lists.__all__
+__all__ += amp_utils.__all__
diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
new file mode 100644
index 0000000000000..81dc32d114b14
--- /dev/null
+++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
+    gray_list as gray_list_fp16, unsupported_fp16_list
+
+__all__ = ["AutoMixedPrecisionListsBF16"]
+
+
+class AutoMixedPrecisionListsBF16(object):
+    """
+    AutoMixedPrecisionListsBF16 is a class for the fp32/bf16 op-type lists. The
+    lists are used by an algorithm that determines each op's execution mode
+    (fp32 or bf16). The class can update the pre-defined fp32 and bf16 lists
+    according to users' custom fp32/bf16 lists.
+
+    Args:
+        custom_bf16_list (set): Users' custom bf16 list.
+        custom_fp32_list (set): Users' custom fp32 list.
+        custom_fp32_varnames (set): Users' custom fp32 variables' names.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            paddle.enable_static()
+            with paddle.static.amp.bf16_guard():
+                paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
+    """
+
+    def __init__(self,
+                 custom_bf16_list=None,
+                 custom_fp32_list=None,
+                 custom_fp32_varnames=None):
+        self._custom_bf16_list = custom_bf16_list
+        self._custom_fp32_list = custom_fp32_list
+        self.bf16_list = copy.copy(bf16_list)
+        self.fp32_list = copy.copy(fp32_list)
+        self.gray_list = copy.copy(gray_list)
+        self.unsupported_list = copy.copy(unsupported_list)
+        self.fp32_varnames = copy.copy(custom_fp32_varnames)
+        self._update_list()
+
+    def _update_list(self):
+        """
+        Update the fp32 and bf16 lists according to users' custom lists.
+        """
+        if self._custom_bf16_list and self._custom_fp32_list:
+            for op_name in self._custom_bf16_list:
+                if op_name in self._custom_fp32_list:
+                    raise ValueError("Custom bf16 list overlaps "
+                                     "custom fp32 list")
+        if self._custom_bf16_list:
+            for op_name in self._custom_bf16_list:
+                if op_name in self.fp32_list:
+                    self.fp32_list.remove(op_name)
+                elif op_name in self.gray_list:
+                    self.gray_list.remove(op_name)
+                self.bf16_list.add(op_name)
+        if self._custom_fp32_list:
+            for op_name in self._custom_fp32_list:
+                if op_name in self.bf16_list:
+                    self.bf16_list.remove(op_name)
+                elif op_name in self.gray_list:
+                    self.gray_list.remove(op_name)
+                self.fp32_list.add(op_name)
+                self.unsupported_list.add(op_name)
+
+
+# always bf16
+bf16_list = {'elementwise_add', }
+
+# depends on the prev_op type
+gray_list = {
+    'reshape2',
+    'lookup_table',
+}
+
+unsupported_list = unsupported_fp16_list.copy().copy()
+fp32_list = black_list_fp16.copy().copy()
+fp32_list |= white_list_fp16
+fp32_list |= gray_list_fp16
+
+fp32_list -= bf16_list
+fp32_list -= gray_list
+unsupported_list -= bf16_list
+unsupported_list -= gray_list
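
To make _update_list's bookkeeping concrete, here is a standalone trace with plain Python sets (op names taken from the lists above; an illustration, not part of the patch):

    bf16_list = {'elementwise_add'}
    gray_list = {'reshape2', 'lookup_table'}
    fp32_list = {'tanh', 'lstm', 'conv2d'}  # illustrative subset of the derived fp32 set

    # custom_bf16_list={'tanh'}: the op leaves fp32_list and joins bf16_list
    for op_name in {'tanh'}:
        if op_name in fp32_list:
            fp32_list.remove(op_name)
        elif op_name in gray_list:
            gray_list.remove(op_name)
        bf16_list.add(op_name)

    # custom_fp32_list={'reshape2'}: a gray-list op is pinned to fp32
    # (the real class additionally adds it to unsupported_list)
    for op_name in {'reshape2'}:
        if op_name in bf16_list:
            bf16_list.remove(op_name)
        elif op_name in gray_list:
            gray_list.remove(op_name)
        fp32_list.add(op_name)

    assert bf16_list == {'elementwise_add', 'tanh'}
    assert fp32_list == {'lstm', 'conv2d', 'reshape2'}
    assert gray_list == {'lookup_table'}
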
+ """ + if self._custom_bf16_list and self._custom_fp32_list: + for op_name in self._custom_bf16_list: + if op_name in self._custom_fp32_list: + raise ValueError("Custom bf16 list overlap " + "custom fp32 list") + if self._custom_bf16_list: + for op_name in self._custom_bf16_list: + if op_name in self.fp32_list: + self.fp32_list.remove(op_name) + elif op_name in self.gray_list: + self.gray_list.remove(op_name) + self.bf16_list.add(op_name) + if self._custom_fp32_list: + for op_name in self._custom_fp32_list: + if op_name in self.bf16_list: + self.bf16_list.remove(op_name) + elif op_name in self.gray_list: + self.gray_list.remove(op_name) + self.fp32_list.add(op_name) + self.unsupported_list.add(op_name) + + +# always bf16 +bf16_list = {'elementwise_add', } + +# depends on the prev_op type +gray_list = { + 'reshape2', + 'lookup_table', +} + +unsupported_list = unsupported_fp16_list.copy().copy() +fp32_list = black_list_fp16.copy().copy() +fp32_list |= white_list_fp16 +fp32_list |= gray_list_fp16 + +fp32_list -= bf16_list +fp32_list -= gray_list +unsupported_list -= bf16_list +unsupported_list -= gray_list diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py new file mode 100644 index 0000000000000..c2c01f88c7431 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -0,0 +1,296 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import struct + +from .... import core +from .... import framework +from ....log_helper import get_logger +from ....wrapped_decorator import signature_safe_contextmanager +from .amp_lists import AutoMixedPrecisionListsBF16 +from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, find_op_index +import logging +import numpy as np + +__all__ = ["bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16"] + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + +_valid_types = [ + core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR_ARRAY +] + +_bf16_guard_pattern = "__use_bf16__" + + +def convert_float_to_uint16(in_list): + in_list = np.asarray(in_list) + out = np.vectorize( + lambda x: struct.unpack('> 16, + otypes=[np.uint16])(in_list.flat) + return np.reshape(out, in_list.shape) + + +def _dtype_to_str(dtype): + """ + Convert specific variable type to its corresponding string. + + Args: + dtype (VarType): Variable type. + """ + if dtype == core.VarDesc.VarType.BF16: + return 'bf16' + else: + return 'fp32' + + +def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): + """ + Insert cast op and rename args of input and output. + + Args: + block (Program): The block in which the operator is. + op (Operator): The operator to insert cast op. + idx (int): The index of current operator. 
+
+
+def _dtype_to_str(dtype):
+    """
+    Convert specific variable type to its corresponding string.
+
+    Args:
+        dtype (VarType): Variable type.
+    """
+    if dtype == core.VarDesc.VarType.BF16:
+        return 'bf16'
+    else:
+        return 'fp32'
+
+
+def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
+    """
+    Insert cast op and rename args of input and output.
+
+    Args:
+        block (Program): The block in which the operator is.
+        op (Operator): The operator to insert cast ops for.
+        idx (int): The index of the current operator.
+        src_dtype (VarType): The input variable dtype of cast op.
+        dest_dtype (VarType): The output variable dtype of cast op.
+
+    Returns:
+        num_cast_ops (int): The number of cast ops that have been inserted.
+    """
+    num_cast_ops = 0
+
+    for in_name in op.input_names:
+        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
+                'batch_norm', 'fused_bn_add_activation', 'layer_norm'
+        ]:
+            if in_name not in {'X', 'Z'}:
+                continue
+        for in_var_name in op.input(in_name):
+            in_var = block.var(in_var_name)
+            if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
+                continue
+            if in_var.dtype == src_dtype:
+                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
+                out_var = block.vars.get(cast_name)
+                if out_var is None or out_var.dtype != dest_dtype:
+                    out_var = block.create_var(
+                        name=cast_name,
+                        dtype=dest_dtype,
+                        persistable=False,
+                        stop_gradient=in_var.stop_gradient)
+
+                    block._insert_op(
+                        idx,
+                        type="cast",
+                        inputs={"X": in_var},
+                        outputs={"Out": out_var},
+                        attrs={
+                            "in_dtype": in_var.dtype,
+                            "out_dtype": out_var.dtype
+                        })
+                    num_cast_ops += 1
+                _rename_arg(op, in_var.name, out_var.name)
+            else:
+                if op.has_attr('in_dtype'):
+                    op._set_attr('in_dtype', dest_dtype)
+    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.BF16:
+        for out_name in op.output_names:
+            if op.type in [
+                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
+            ] and out_name != 'Y':
+                continue
+            for out_var_name in op.output(out_name):
+                out_var = block.var(out_var_name)
+                if out_var.type not in _valid_types:
+                    continue
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+                    if op.has_attr('out_dtype'):
+                        op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
+    return num_cast_ops
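
Concretely, for an op kept in fp32 whose input was produced in bf16, the pass materializes the pattern sketched below (hypothetical op and variable names):

    # Before: 'matmul' stays in the fp32 set but consumes 'x', a bf16 (uint16) variable:
    #
    #     x (bf16) --> matmul --> out
    #
    # After _insert_cast_op(block, matmul_op, idx, VarType.BF16, VarType.FP32):
    #
    #     x (bf16) --> cast --> x.cast_fp32 (fp32) --> matmul --> out
    #
    # The cast output is named in_var.name + '.cast_' + _dtype_to_str(dest_dtype),
    # the consumer's input argument is renamed to it via _rename_arg, and the
    # returned num_cast_ops tells the caller how far the op index has shifted.
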
+
+
+def _is_in_fp32_varnames(op, amp_lists):
+    for in_name in op.input_arg_names:
+        if in_name in amp_lists.fp32_varnames:
+            return True
+
+    for out_name in op.output_arg_names:
+        if out_name in amp_lists.fp32_varnames:
+            return True
+
+    return False
+
+
+def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
+    if op.type in unsupported_op_list:
+        # the highest-priority condition: if an op doesn't have a bf16 computing
+        # kernel, it must be executed in fp32 calculation pattern
+        return True
+
+    # process ops related to the learning rate
+    in_out_arg_names = []
+    in_out_arg_names.extend(list(op.input_arg_names))
+    in_out_arg_names.extend(list(op.output_arg_names))
+    for name in in_out_arg_names:
+        if "learning_rate" in name:
+            return True
+
+    if use_bf16_guard:
+        if op.has_attr("op_namescope") and \
+                (_bf16_guard_pattern in op.attr("op_namescope")):
+            # op in bf16 guard
+            return False
+        else:
+            # op not in bf16 guard
+            return True
+    else:
+        return False
+
+
+@signature_safe_contextmanager
+def bf16_guard():
+    """
+    For pure bf16 training, if users set `use_bf16_guard` to True, only the
+    ops created in the context manager `bf16_guard` will be transformed to
+    bfloat16 type.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.nn.functional as F
+            paddle.enable_static()
+            data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
+            conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
+
+            with paddle.static.amp.bf16_guard():
+                bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
+                pool = F.max_pool2d(bn, kernel_size=2, stride=2)
+                hidden = paddle.static.nn.fc(pool, size=10)
+                loss = paddle.mean(hidden)
+    """
+    with framework.name_scope(prefix=_bf16_guard_pattern):
+        yield
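
The guard itself is only a name scope: ops created inside it carry the __use_bf16__ marker in their op_namescope attribute, which _need_keep_fp32 above checks. A standalone sketch of that matching with a stub op class (not Paddle code):

    _bf16_guard_pattern = "__use_bf16__"

    class FakeOp(object):
        # stand-in for a framework Operator, just enough for the guard check
        def __init__(self, namescope):
            self._namescope = namescope

        def has_attr(self, name):
            return name == "op_namescope"

        def attr(self, name):
            return self._namescope

    guarded = FakeOp("/__use_bf16__/fc")  # as if created under bf16_guard()
    unguarded = FakeOp("/fc")             # created outside the guard

    for op in (guarded, unguarded):
        in_guard = op.has_attr("op_namescope") and \
            _bf16_guard_pattern in op.attr("op_namescope")
        print(in_guard)  # True, then False
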
+
+
+def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
+    """
+    Traverse all ops in the current block and insert cast ops according to
+    which set each op belongs to.
+
+    1. When an op belongs to the fp32 list, add it to the fp32 set.
+    2. When an op belongs to the bf16 list, add it to the bf16 set.
+    3. When an op belongs to the gray list: if one of its inputs is the
+       output of an fp32-set op or an fp32-list op, add it to the fp32 set;
+       if none of its previous ops are fp32 ops and one of its inputs is the
+       output of a bf16-set op or a bf16-list op, add it to the bf16 set.
+    4. When an op isn't in any of the lists, add it to the fp32 set.
+    5. Add the cast ops needed to make sure that fp32-set ops are computed
+       in fp32 mode, while bf16-set ops are computed in bf16 mode.
+
+    Args:
+        main_prog (Program): The main program for training.
+        amp_lists (AutoMixedPrecisionListsBF16): The fp32/bf16 op lists to
+            use; a default-constructed one is used when None is passed.
+        use_bf16_guard (bool): Whether only ops inside `bf16_guard` may be
+            converted to bf16.
+    """
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionListsBF16()
+    block = main_prog.global_block()
+    ops = block.ops
+    bf16_op_set = set()
+    fp32_op_set = set()
+    for op in ops:
+
+        # NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the non-iterable DataLoader,
+        # we don't need to handle reader ops, and the input of 'create_py_reader' is not
+        # in the block, which may result in errors.
+        # See GeneratorLoader._init_non_iterable() for details.
+        if op.type == 'create_py_reader' or op.type == 'read':
+            continue
+
+        if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
+                op, amp_lists):
+            fp32_op_set.add(op)
+            continue
+
+        if op.type in amp_lists.fp32_list or _need_keep_fp32(
+                op, amp_lists.unsupported_list, use_bf16_guard):
+            fp32_op_set.add(op)
+        elif op.type in amp_lists.bf16_list:
+            bf16_op_set.add(op)
+        elif op.type in amp_lists.gray_list:
+            is_fp32_op = False
+            is_bf16_op = False
+            for in_name in op.input_names:
+                # if this op has inputs
+                if in_name:
+                    for in_var_name in op.input(in_name):
+                        in_var = block.var(in_var_name)
+                        # this in_var isn't the output of other op
+                        if in_var.op is None:
+                            continue
+                        elif in_var.op is op:
+                            prev_op = find_true_prev_op(ops, op, in_var_name)
+                            if prev_op is None:
+                                continue
+                        else:
+                            prev_op = in_var.op
+                        # if it's one of the inputs
+                        if prev_op in fp32_op_set or \
+                                prev_op.type in amp_lists.fp32_list:
+                            is_fp32_op = True
+                        elif prev_op in bf16_op_set or \
+                                prev_op.type in amp_lists.bf16_list:
+                            is_bf16_op = True
+            if is_fp32_op:
+                fp32_op_set.add(op)
+            elif is_bf16_op:
+                bf16_op_set.add(op)
+            else:
+                pass
+        else:
+            # For numerical safety, we apply fp32 computation to the ops for
+            # which it hasn't been determined which list they belong to.
+            fp32_op_set.add(op)
+
+    idx = 0
+    while idx < len(ops):
+        op = ops[idx]
+        num_cast_ops = 0
+        if op in fp32_op_set:
+            num_cast_ops = _insert_cast_op(block, op, idx,
+                                           core.VarDesc.VarType.BF16,
+                                           core.VarDesc.VarType.FP32)
+        elif op in bf16_op_set:
+            if op.has_attr('use_mkldnn'):
+                op._set_attr('use_mkldnn', True)
+                op._set_attr('mkldnn_data_type', 'bfloat16')
+            elif op.has_attr('dtype') and op.attr(
+                    'dtype') == core.VarDesc.VarType.FP32:
+                op._set_attr('dtype', core.VarDesc.VarType.BF16)
+
+            num_cast_ops = _insert_cast_op(block, op, idx,
+                                           core.VarDesc.VarType.FP32,
+                                           core.VarDesc.VarType.BF16)
+        else:
+            pass
+
+        idx += num_cast_ops + 1
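
A minimal end-to-end sketch of driving the pass (assumes a CPU build of Paddle with this patch and oneDNN support; the gray-list reshape2 op follows its bf16 producer per rule 3 above):

    import paddle
    import paddle.fluid as fluid
    import paddle.static.amp as amp

    paddle.enable_static()
    main_prog = fluid.Program()
    with fluid.program_guard(main_prog, fluid.Program()):
        x = fluid.data(name='x', shape=[3, 3], dtype='float32')
        y = fluid.data(name='y', shape=[3, 3], dtype='float32')
        out = fluid.layers.elementwise_add(x, y)  # 'elementwise_add' is in the bf16 list
        out = fluid.layers.reshape(out, [0, 0])   # 'reshape2' is gray, follows its producer
        amp.rewrite_program_bf16(main_prog)

    # After the rewrite, the bf16-set ops are marked for oneDNN bf16 execution
    # and cast ops have been inserted around the fp32 boundary:
    for op in main_prog.global_block().ops:
        if op.has_attr('mkldnn_data_type'):
            print(op.type, op.attr('mkldnn_data_type'))  # e.g. elementwise_add bfloat16
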
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
index c88ae2d9cbf60..6a524af4ee240 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -69,7 +69,7 @@ def _update_list(self):
         self.unsupported_list.add(op_name)
 
 
-# The three sets listed below are changed dynamiclly. They don't contain all 
+# The three sets listed below are changed dynamiclly. They don't contain all
 # paddle ops currently.
 
 # The set of ops that support fp16 calculation and are considered numerically-
diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py
new file mode 100644
index 0000000000000..faf2307f8147b
--- /dev/null
+++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.contrib.mixed_precision as amp
+from paddle.fluid import core
+import paddle
+
+paddle.enable_static()
+
+
+class AMPTest(unittest.TestCase):
+    def setUp(self):
+        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
+        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
+        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
+        self.amp_lists_ = None
+
+    def tearDown(self):
+        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
+        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
+        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)
+
+    def test_amp_lists(self):
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()
+
+    def test_amp_lists_1(self):
+        # 1. w={'exp'}, b=None
+        self.bf16_list.add('exp')
+        self.fp32_list.remove('exp')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})
+
+    def test_amp_lists_2(self):
+        # 2. w={'tanh'}, b=None
+        self.fp32_list.remove('tanh')
+        self.bf16_list.add('tanh')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})
+
+    def test_amp_lists_3(self):
+        # 3. w={'lstm'}, b=None
+        self.bf16_list.add('lstm')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})
+
+    def test_amp_lists_4(self):
+        # 4. w=None, b={'elementwise_add'}
+        self.bf16_list.remove('elementwise_add')
+        self.fp32_list.add('elementwise_add')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_list={'elementwise_add'})
+
+    def test_amp_lists_5(self):
+        # 5. w=None, b={'elementwise_add'}
+        self.fp32_list.add('elementwise_add')
+        self.bf16_list.remove('elementwise_add')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_list={'elementwise_add'})
+
+    def test_amp_lists_6(self):
+        # 6. w=None, b={'lstm'}
+        self.fp32_list.add('lstm')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_list={'lstm'})
+
+    def test_amp_lists_7(self):
+        self.fp32_list.add('reshape2')
+        self.gray_list.remove('reshape2')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_list={'reshape2'})
+
+    def test_amp_list_8(self):
+        self.bf16_list.add('reshape2')
+        self.gray_list.remove('reshape2')
+
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_bf16_list={'reshape2'})
+
+
+class AMPTest2(unittest.TestCase):
+    def test_amp_lists_(self):
+        # 7. w={'lstm'} b={'lstm'}
+        # raise ValueError
+        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
+                          {'lstm'}, {'lstm'})
+
+    def test_find_op_index(self):
+        block = fluid.default_main_program().global_block()
+        op_desc = core.OpDesc()
+        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
+        assert (idx == -1)
+
+    def test_is_in_fp32_varnames(self):
+        block = fluid.default_main_program().global_block()
+
+        var1 = block.create_var(name="X", shape=[3], dtype='float32')
+        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
+        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
+        op1 = block.append_op(
+            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
+        op2 = block.append_op(
+            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
+        amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_varnames={'X'})
+        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
+        amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_varnames={'Y'})
+        assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
+        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)
+
+    def test_find_true_post_op(self):
+
+        block = fluid.default_main_program().global_block()
+
+        var1 = block.create_var(name="X", shape=[3], dtype='float32')
+        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
+        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
+        op1 = block.append_op(
+            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
+        op2 = block.append_op(
+            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
+        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
+        assert (res == [op2])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
new file mode 100644
index 0000000000000..40ddcf2e66b75
--- /dev/null
+++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import paddle
+import paddle.fluid as fluid
+import contextlib
+import unittest
+import numpy as np
+import paddle.fluid.layers as layers
+import paddle.static.amp as amp
+from paddle.fluid import core
+
+paddle.enable_static()
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestModelCastBF16(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.seed = 111
+
+    @classmethod
+    def tearDownClass(cls):
+        pass
+
+    @contextlib.contextmanager
+    def static_graph(self):
+        with self.scope_prog_guard():
+            paddle.seed(self.seed)
+            paddle.framework.random._manual_program_seed(self.seed)
+            yield
+
+    @contextlib.contextmanager
+    def scope_prog_guard(self):
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        scope = fluid.core.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(prog, startup_prog):
+                yield
+
+    def get_static_graph_result(self, feed, fetch_list, amp_fun,
+                                with_lod=False):
+        exe = fluid.Executor(core.CPUPlace())
+        exe.run(fluid.default_startup_program())
+        prog = fluid.default_main_program()
+        if amp_fun is not None:
+            amp_fun(prog)
+        return exe.run(prog,
+                       feed=feed,
+                       fetch_list=fetch_list,
+                       return_numpy=(not with_lod))
+
+    def test_graph_rewrite(self):
+        size = 3
+        n = np.ones([size, size], dtype='float32') * 3.2
+        nn = np.ones([size, size], dtype='float32') * -2.7
+
+        n_bf16 = amp.convert_float_to_uint16(n)
+        nn_bf16 = amp.convert_float_to_uint16(nn)
+
+        with self.static_graph():
+            t_bf16 = layers.data(
+                name='t_bf16', shape=[size, size], dtype=np.uint16)
+            tt_bf16 = layers.data(
+                name='tt_bf16', shape=[size, size], dtype=np.uint16)
+            t = layers.data(name='t', shape=[size, size], dtype='float32')
+            tt = layers.data(name='tt', shape=[size, size], dtype='float32')
+
+            ret = layers.elementwise_add(t, tt)
+            ret = layers.elementwise_mul(ret, t)
+            ret = layers.reshape(ret, [0, 0])
+
+            with amp.bf16_guard():
+                ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
+                ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
+                ret_bf16 = layers.reshape(ret_bf16, [0, 0])
+
+            with amp.bf16_guard():
+                ret_fp32bf16 = layers.elementwise_add(t, tt)
+                ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
+                ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0])
+
+            static_ret_bf16, static_ret, ret_fp32bf16 = self.get_static_graph_result(
+                feed={
+                    't': n,
+                    'tt': nn,
+                    't_bf16': n_bf16,
+                    'tt_bf16': nn_bf16,
+                },
+                fetch_list=[ret_bf16, ret, ret_fp32bf16],
+                amp_fun=lambda prog: amp.rewrite_program_bf16(prog, use_bf16_guard=True))
+
+        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
+        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
+
+        with self.static_graph():
+            t = layers.data(name='t', shape=[size, size], dtype='float32')
+            tt = layers.data(name='tt', shape=[size, size], dtype='float32')
+
+            with amp.bf16_guard():
+                ret = layers.elementwise_add(t, tt)
+                ret = layers.reshape(ret, [0, 0], act='elu')
+                ret = layers.elementwise_mul(ret, t)
+                ret = layers.elementwise_add(ret, tt)
+
+            static_ret_bf16 = \
+                self.get_static_graph_result(
+                    feed={'t': n, 'tt': nn},
+                    fetch_list=[ret],
+                    amp_fun=lambda prog: amp.rewrite_program_bf16(
+                        prog,
+                        amp.AutoMixedPrecisionListsBF16(
+                            custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
+                        use_bf16_guard=True
+                    )
+                )
+        self.assertTrue(
+            np.allclose(static_ret_bf16, np.ones(
+                [size, size], dtype='float32') * -1.1, 1e-2))
+
+
+if __name__ == '__main__':
+    unittest.main()
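
The 1e-2 tolerances in test_graph_rewrite come from bfloat16's 8-bit significand, which gives only two to three significant decimal digits. For the 3.2 inputs used above, the truncation error can be inspected directly with the helper from this patch (a standalone check, not part of the test):

    import numpy as np
    import paddle.static.amp as amp

    n = np.ones([3, 3], dtype='float32') * 3.2
    bits = amp.convert_float_to_uint16(n)
    # widen the stored bits back into the high half of a float32 word
    back = (bits.astype(np.uint32) << 16).view(np.float32)
    print(back[0, 0])                    # 3.1875
    print(np.abs(back - n).max() / 3.2)  # ~3.9e-3, comfortably inside 1e-2
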
diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index b2db00296bf95..52be7493cf229 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -29,6 +29,7 @@
 _PADDLE_DTYPE_2_NUMPY_DTYPE = {
     core.VarDesc.VarType.BOOL: 'bool',
     core.VarDesc.VarType.FP16: 'float16',
+    core.VarDesc.VarType.BF16: 'uint16',
     core.VarDesc.VarType.FP32: 'float32',
     core.VarDesc.VarType.FP64: 'float64',
     core.VarDesc.VarType.INT8: 'int8',
@@ -47,16 +48,18 @@ def convert_dtype(dtype):
         return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
     elif isinstance(dtype, type):
         if dtype in [
-                np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
-                np.int32, np.int64, np.uint8, np.complex64, np.complex128
+                np.bool, np.float16, np.uint16, np.float32, np.float64, np.int8,
+                np.int16, np.int32, np.int64, np.uint8, np.complex64,
+                np.complex128
         ]:
             return dtype.__name__
     else:
         if dtype in [
-                'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
-                'int32', 'int64', 'uint8', 'complex64', 'complex128', u'bool',
-                u'float16', u'float32', u'float64', u'int8', u'int16', u'int32',
-                u'int64', u'uint8', u'complex64', u'complex128'
+                'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
+                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
+                u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
+                u'int16', u'int32', u'int64', u'uint8', u'complex64',
+                u'complex128'
         ]:
             # this code is a little bit dangerous, since error could happen
             # when casting no-ascii code to str in python2.
@@ -66,7 +69,7 @@ def convert_dtype(dtype):
             return str(dtype)
 
     raise TypeError(
-        "dtype must be any of [bool, float16, float32, float64, int8, int16, "
+        "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
         "int32, int64, uint8, complex64, complex128], but received %s" % dtype)
 
 
@@ -123,6 +126,12 @@ def check_dtype(input_dtype,
         warnings.warn(
             "The data type of '%s' in %s only support float16 in GPU now. %s" %
             (input_name, op_name, extra_message))
+    if convert_dtype(input_dtype) in ['uint16'] and op_name not in [
+            'reshape', 'lookup_table', 'scale'
+    ]:
+        warnings.warn(
+            "The data type of '%s' in %s only support bfloat16 in OneDNN now. %s"
+            % (input_name, op_name, extra_message))
     if convert_dtype(input_dtype) not in expected_dtype:
         raise TypeError(
             "The data type of '%s' in %s must be %s, but received %s. %s" %
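
With the data_feeder change, 'uint16' becomes the numpy-facing name of VarType.BF16, and check_dtype warns when a uint16 input reaches an op other than the three bf16-enabled ones. A quick sanity check (assumes a build with this patch applied):

    import numpy as np
    from paddle.fluid import core
    from paddle.fluid.data_feeder import convert_dtype

    print(convert_dtype(core.VarDesc.VarType.BF16))  # 'uint16'
    print(convert_dtype(np.uint16))                  # 'uint16'
    print(convert_dtype('uint16'))                   # 'uint16'
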
%s" % diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index fa8df14c8669b..00d1db19fc2f5 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6137,9 +6137,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): return dygraph_utils._append_activation_in_dygraph(out, act) - check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', - 'bool'], 'reshape') + check_variable_and_dtype(x, 'x', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16' + ], 'reshape') check_type(shape, 'shape', (list, tuple, Variable), 'reshape') check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape') @@ -11354,9 +11354,11 @@ def _elementwise_op(helper): assert x is not None, 'x cannot be None in {}'.format(op_type) assert y is not None, 'y cannot be None in {}'.format(op_type) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) + x, 'x', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + op_type) check_variable_and_dtype( - y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) + y, 'y', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + op_type) axis = helper.kwargs.get('axis', -1) use_mkldnn = helper.kwargs.get('use_mkldnn', False) @@ -11428,8 +11430,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): return dygraph_utils._append_activation_in_dygraph(out) check_variable_and_dtype(x, "x", [ - 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8' + 'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32', + 'int64', 'uint8' ], "scale") inputs = {'X': [x]} attrs = { diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 9a2cc4ab1a1b9..df43d9366ff78 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -26,7 +26,7 @@ paddle.enable_static() -def train(use_cuda, save_dirname, is_local): +def train(use_cuda, save_dirname, is_local, use_bf16): x = fluid.layers.data(name='x', shape=[13], dtype='float32') y_predict = fluid.layers.fc(input=x, size=1, act=None) @@ -37,6 +37,8 @@ def train(use_cuda, save_dirname, is_local): avg_cost = fluid.layers.mean(cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + if use_bf16: + paddle.static.amp.rewrite_program_bf16(fluid.default_main_program()) sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 20 @@ -133,14 +135,17 @@ def infer(use_cuda, save_dirname=None): print("ground truth: ", test_label) -def main(use_cuda, is_local=True): +def main(use_cuda, is_local=True, use_bf16=False): if use_cuda and not fluid.core.is_compiled_with_cuda(): return + if use_bf16 and not fluid.core.is_compiled_with_mkldnn(): + return + # Directory for saving the trained model save_dirname = "fit_a_line.inference.model" - train(use_cuda, save_dirname, is_local) + train(use_cuda, save_dirname, is_local, use_bf16) infer(use_cuda, save_dirname) @@ -153,6 +158,12 @@ def test_cuda(self): with self.program_scope_guard(): main(use_cuda=True) + @unittest.skipIf(not fluid.core.supports_bfloat16(), + "place does not support BF16 evaluation") + def test_bf16(self): + with self.program_scope_guard(): + main(use_cuda=False, use_bf16=True) + @contextlib.contextmanager def program_scope_guard(self): prog = fluid.Program() diff --git 
diff --git a/python/paddle/fluid/tests/book/test_word2vec_book.py b/python/paddle/fluid/tests/book/test_word2vec_book.py
index e33b1cc514aa6..ad7550fa9dd96 100644
--- a/python/paddle/fluid/tests/book/test_word2vec_book.py
+++ b/python/paddle/fluid/tests/book/test_word2vec_book.py
@@ -39,7 +39,12 @@ def get_place(target):
                          format(target))
 
 
-def train(target, is_sparse, is_parallel, save_dirname, is_local=True):
+def train(target,
+          is_sparse,
+          is_parallel,
+          save_dirname,
+          is_local=True,
+          use_bf16=False):
     PASS_NUM = 100
     EMBED_SIZE = 32
     HIDDEN_SIZE = 256
@@ -101,6 +106,8 @@ def __network__(words):
         raise NotImplementedError()
 
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    if use_bf16:
+        paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
     sgd_optimizer.minimize(avg_cost)
 
     train_reader = paddle.batch(
@@ -239,12 +246,15 @@ def to_infer_tensor(lod_tensor):
         assert np.isclose(a, b, rtol=5e-5), "a: {}, b: {}".format(a, b)
 
 
-def main(target, is_sparse, is_parallel):
+def main(target, is_sparse, is_parallel, use_bf16):
     if target == "cuda" and not fluid.core.is_compiled_with_cuda():
         return
     if target == "xpu" and not fluid.core.is_compiled_with_xpu():
         return
 
+    if use_bf16 and not fluid.core.is_compiled_with_mkldnn():
+        return
+
     if not is_parallel:
         save_dirname = "word2vec.inference.model"
     else:
@@ -255,7 +265,7 @@ def main(target, is_sparse, is_parallel):
         # so only inference is turned on.
         train("cpu", is_sparse, is_parallel, save_dirname)
     else:
-        train(target, is_sparse, is_parallel, save_dirname)
+        train(target, is_sparse, is_parallel, save_dirname, use_bf16=use_bf16)
     infer(target, save_dirname)
 
 
@@ -268,10 +278,11 @@ class W2VTest(unittest.TestCase):
     pass
 
 
-def inject_test_method(target, is_sparse, is_parallel):
-    fn_name = "test_{0}_{1}_{2}".format(target, "sparse"
-                                        if is_sparse else "dense", "parallel"
-                                        if is_parallel else "normal")
+def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
+    fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse"
+                                           if is_sparse else "dense", "parallel"
+                                           if is_parallel else "normal", "_bf16"
+                                           if use_bf16 else "")
 
     def __impl__(*args, **kwargs):
         prog = fluid.Program()
@@ -279,8 +290,7 @@ def __impl__(*args, **kwargs):
         scope = fluid.core.Scope()
         with fluid.scope_guard(scope):
             with fluid.program_guard(prog, startup_prog):
-                main(
-                    target=target, is_sparse=is_sparse, is_parallel=is_parallel)
+                main(target, is_sparse, is_parallel, use_bf16)
 
     if (not fluid.core.is_compiled_with_cuda() or
             target == "cuda") and is_sparse:
@@ -297,6 +307,7 @@ def __impl__(*args, **kwargs):
 for is_sparse in (False, True):
     for is_parallel in (False, ):
         inject_test_method(target, is_sparse, is_parallel)
+inject_test_method("cpu", False, False, use_bf16=True)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 939e2ac0f59fd..dff96a8cbc3c4 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -244,17 +244,12 @@ def convert_float_to_uint16(float_list, data_format="NCHW"):
     return new_output
 
 
-def copy_bits_from_uint16_to_float(i):
-    i = np.uint32(i) << 16
-    return struct.unpack('<f', struct.pack('<I', i))[0]