From 776eca313a8a0ee3e8b055071bdcda5042ca265c Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 22 Feb 2021 00:22:31 +0100 Subject: [PATCH 01/33] Initial bf16 amp integration --- paddle/fluid/operators/cast_op.cc | 1 + .../fluid/contrib/mixed_precision/__init__.py | 9 + .../fluid/contrib/mixed_precision/amp_nn.py | 3 + .../contrib/mixed_precision/bf16_lists.py | 44 ++ .../contrib/mixed_precision/bf16_utils.py | 631 ++++++++++++++++++ .../contrib/mixed_precision/decorator.py | 4 +- .../contrib/mixed_precision/decorator_bf16.py | 526 +++++++++++++++ .../contrib/mixed_precision/fp16_lists.py | 37 +- .../contrib/mixed_precision/fp16_utils.py | 4 +- .../tests/test_image_classification_fp16.py | 25 +- .../contrib/tests/test_model_cast_to_bf16.py | 101 +++ python/paddle/fluid/data_feeder.py | 21 +- python/paddle/fluid/layers/nn.py | 6 +- .../paddle/fluid/tests/unittests/op_test.py | 17 +- 14 files changed, 1386 insertions(+), 43 deletions(-) create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16_lists.py create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16_utils.py create mode 100644 python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py create mode 100644 python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index c5cfa7a3bafce..40f4b969ec060 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, + ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel); diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index a580ae5574c35..baab9e167c069 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -16,11 +16,20 @@ from . import decorator from .decorator import * +from . import decorator_bf16 +from .decorator_bf16 import * from . import fp16_lists from .fp16_lists import * +from . import bf16_lists +from .bf16_lists import * from . import fp16_utils from .fp16_utils import * +from . import bf16_utils +from .bf16_utils import * __all__ = decorator.__all__ +__all__ = decorator_bf16.__all__ __all__ += fp16_lists.__all__ +__all__ += bf16_lists.__all__ __all__ += fp16_utils.__all__ +__all__ += bf16_utils.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index 3bfc078971d7a..16b117bf0f4e7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -97,6 +97,9 @@ def update_loss_scaling(x, if e.dtype == core.VarDesc.VarType.FP16: assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." + elif e.dtype == core.VarDesc.VarType.BF16: + assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ + "The dtype of prev_loss_scaling should be float32 when the dtype of x is bfloat16." else: assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." 
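For context, a minimal usage sketch of the BF16 AMP API that the files below introduce (`decorate_bf16`, `AutoMixedPrecisionListsBF16`, exposed through `paddle.fluid.contrib.mixed_precision`). This is only a sketch against the API as it stands in this patch; the toy network, the SGD optimizer, and the `custom_white_list` entry are illustrative assumptions, not part of the patch:

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib import mixed_precision

    paddle.enable_static()

    # A small fp32 static-graph network; any real model would do here.
    data = fluid.data(name='X', shape=[None, 1], dtype='float32')
    hidden = fluid.layers.fc(input=data, size=10)
    loss = fluid.layers.mean(hidden)

    # The default bf16 white list in this patch contains only 'elementwise_add';
    # 'elementwise_mul' is added here purely as an example of a custom entry.
    amp_lists = mixed_precision.AutoMixedPrecisionListsBF16(
        custom_white_list={'elementwise_mul'})

    optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    mp_optimizer = mixed_precision.decorate_bf16(
        optimizer, amp_lists=amp_lists, use_pure_bf16=False)
    mp_optimizer.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    # With use_pure_bf16=True, mp_optimizer.amp_init(place) would additionally
    # cast the fp32 parameters to bf16 after they have been initialized.

With `use_pure_bf16=False` the program is rewritten op by op (white-list ops run in bf16, black-list ops stay in fp32); with `use_pure_bf16=True` the whole model is cast via `cast_model_to_bf16` and `bf16_guard` controls which ops participate.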
diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py new file mode 100644 index 0000000000000..b41cb10622785 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fp16_lists import AutoMixedPrecisionLists, \ + white_list as white_list_fp16, black_list as black_list_fp16, \ + gray_list as gray_list_fp16, unsupported_fp16_list + +__all__ = ["AutoMixedPrecisionListsBF16"] + + +class AutoMixedPrecisionListsBF16(AutoMixedPrecisionLists): + def __init__(self, + custom_white_list=None, + custom_black_list=None, + custom_black_varnames=None): + super(AutoMixedPrecisionListsBF16, self).__init__( + white_list, + black_list, + gray_list, + unsupported_list, + custom_white_list=custom_white_list, + custom_black_list=custom_black_list, + custom_black_varnames=custom_black_varnames) + + +white_list = {'elementwise_add'} +black_list = black_list_fp16.copy().copy() +black_list.update(white_list_fp16) +black_list.update(gray_list_fp16) +gray_list = set() +unsupported_list = unsupported_fp16_list + +CustomOpListsBF16 = AutoMixedPrecisionListsBF16 diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py new file mode 100644 index 0000000000000..e8e8372b066f2 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -0,0 +1,631 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from ... import core +from ... import framework +from ... import layers +from ... import global_scope +from ...log_helper import get_logger +from ...wrapped_decorator import signature_safe_contextmanager +from .bf16_lists import AutoMixedPrecisionListsBF16 +import collections +import logging +import numpy as np + +__all__ = ["bf16_guard", "cast_model_to_bf16", "cast_parameters_to_bf16"] + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + +_valid_types = [ + core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR_ARRAY +] + +_bf16_guard_pattern = "__use_bf16__" + + +def _rename_arg(op, old_name, new_name): + """ + If an op has old_name input and output, rename these input + args new_name. 
+ + Args: + op (Operator): Current operator. + old_name (str): The old name of input args. + new_name (str): The new name of input args. + """ + op_desc = op.desc + if isinstance(op_desc, tuple): + op_desc = op_desc[0] + op_desc._rename_input(old_name, new_name) + op_desc._rename_output(old_name, new_name) + + +def _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops): + for block in program.blocks: + ops = block.ops + block_id = block.idx + for op in ops: + if op not in origin_ops or op in keep_fp32_ops: + continue + for name in op.input_arg_names: + if name in op_var_rename_map[block_id]: + op._rename_input(name, op_var_rename_map[block_id][name]) + + +def _dtype_to_str(dtype): + """ + Convert specific variable type to its corresponding string. + + Args: + dtype (VarType): Variable type. + """ + if dtype == core.VarDesc.VarType.BF16: + return 'bf16' + else: + return 'fp32' + + +def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): + """ + Insert cast op and rename args of input and output. + + Args: + block (Program): The block in which the operator is. + op (Operator): The operator to insert cast op. + idx (int): The index of current operator. + src_dtype (VarType): The input variable dtype of cast op. + dest_dtype (VarType): The output variable dtype of cast op. + + Returns: + num_cast_op (int): The number of cast ops that have been inserted. + """ + num_cast_ops = 0 + + for in_name in op.input_names: + if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + ]: + if in_name not in {'X', 'Z'}: + continue + for in_var_name in op.input(in_name): + in_var = block.var(in_var_name) + if in_var.type not in _valid_types or in_var.dtype == dest_dtype: + continue + if in_var.dtype == src_dtype: + cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) + out_var = block.vars.get(cast_name) + if out_var is None or out_var.dtype != dest_dtype: + out_var = block.create_var( + name=cast_name, + dtype=dest_dtype, + persistable=False, + stop_gradient=in_var.stop_gradient) + + block._insert_op( + idx, + type="cast", + inputs={"X": in_var}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": in_var.dtype, + "out_dtype": out_var.dtype + }) + num_cast_ops += 1 + _rename_arg(op, in_var.name, out_var.name) + else: + if op.has_attr('in_dtype'): + op._set_attr('in_dtype', dest_dtype) + if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.BF16: + for out_name in op.output_names: + if op.type in [ + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + ] and out_name != 'Y': + continue + for out_var_name in op.output(out_name): + out_var = block.var(out_var_name) + if out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.BF16) + if op.has_attr('out_dtype'): + op._set_attr('out_dtype', core.VarDesc.VarType.BF16) + return num_cast_ops + + +def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, + op_var_rename_map): + num_cast_ops = 0 + + target_var = block.var(target_name) + if target_var.type not in _valid_types or target_var.dtype == dest_dtype: + return num_cast_ops + + assert target_var.dtype == src_dtype, \ + "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) + + cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) + cast_var = block.vars.get(cast_name) + if cast_var is None or cast_var.dtype != 
dest_dtype: + cast_var = block.create_var( + name=cast_name, + dtype=dest_dtype, + persistable=False, + stop_gradient=target_var.stop_gradient) + block._insert_op( + idx, + type="cast", + inputs={"X": target_var}, + outputs={"Out": cast_var}, + attrs={"in_dtype": target_var.dtype, + "out_dtype": cast_var.dtype}) + num_cast_ops += 1 + op_var_rename_map[block.idx][target_var.name] = cast_var.name + + return num_cast_ops + + +def find_true_prev_op(ops, cur_op, var_name): + """ + Find the true prev op that outputs var_name variable. + + Args: + ops (list): A list of ops. + cur_op (Operator): Current operator which has var_name variable. + var_name (string): Variable name. + """ + prev_op = [] + for op in ops: + if op == cur_op: + break + for out_name in op.output_names: + for out_var_name in op.output(out_name): + if out_var_name == var_name: + prev_op.append(op) + if prev_op: + if not len(prev_op) == 1: + raise ValueError("There must be only one previous op " + "that outputs {0} variable".format(var_name)) + else: + return prev_op[0] + return None + + +def find_true_post_op(ops, cur_op, var_name): + """ + if there are post ops, return them, if there is no post op, + return None instead. + Args: + ops (list): A list of ops. + cur_op (Operator): Current operator which has var_name variable. + var_name (string): Variable name. + """ + post_op = [] + for idx, op in enumerate(ops): + if op == cur_op: + break + + for i in range(idx + 1, len(ops)): + op = ops[i] + for in_name in op.input_names: + for in_var_name in op.input(in_name): + if in_var_name == var_name: + post_op.append(op) + + return post_op + + +def find_op_index(block_desc, cur_op_desc): + """ + """ + for idx in range(block_desc.op_size()): + if cur_op_desc == block_desc.op(idx): + return idx + return -1 + + +def _is_in_black_varnames(op, amp_lists): + for in_name in op.input_arg_names: + if in_name in amp_lists.black_varnames: + return True + + for out_name in op.output_arg_names: + if out_name in amp_lists.black_varnames: + return True + + return False + + +def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard): + if op.type in unsupported_op_list: + # the highest priority condition: If ops don't have bf16 computing kernels, + # they must be executed in fp32 calculation pattern. + return True + + # process ops about learning rate + in_out_arg_names = [] + in_out_arg_names.extend(list(op.input_arg_names)) + in_out_arg_names.extend(list(op.output_arg_names)) + for name in in_out_arg_names: + if "learning_rate" in name: + return True + + if use_bf16_guard: + if op.has_attr("op_namescope") and \ + (_bf16_guard_pattern in op.attr("op_namescope")): + # op in bf16 guard + return False + else: + # op not in bf16 guard + return True + else: + return False + + +@signature_safe_contextmanager +def bf16_guard(): + """ + As for the pure bf16 training, if users set `use_bf16_guard` to True, + only those ops created in the context manager `bf16_guard` will be + transformed as float16 type. + + Examples: + .. 
code-block:: python + + import numpy as np + import paddle + import paddle.nn.functional as F + paddle.enable_static() + data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') + conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) + + with paddle.static.amp.bf16_guard(): + bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") + pool = F.max_pool2d(bn, kernel_size=2, stride=2) + hidden = paddle.static.nn.fc(pool, size=10) + loss = paddle.mean(hidden) + """ + with framework.name_scope(prefix=_bf16_guard_pattern): + yield + + +def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): + """ + Traverse all ops in the whole model and set their inputs and outputs + to the bf16 data type. This function will do some special process for + the batch normalization, which keeps the computational process of + batchnorms in FP32. + Args: + program (Program): The used program. + amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object. + use_bf16_guard(bool): Determine whether to use `bf16_guard` when + constructing the program. Default True. + """ + + if amp_lists is None: + amp_lists = AutoMixedPrecisionListsBF16() + global_block = program.global_block() + keep_fp32_ops = set() + to_bf16_var_names = set() + to_bf16_pre_cast_ops = set() + origin_ops = [] + for block in program.blocks: + origin_ops.extend(block.ops) + + for block in program.blocks: + ops = block.ops + for op in ops: + if op.type == 'create_py_reader' or op.type == 'read': + continue + if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard): + keep_fp32_ops.add(op) + continue # processed below + for in_name in op.input_names: + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and in_name not in {'X', 'Z'}: + continue + for in_var_name in op.input(in_name): + in_var = None + try: + in_var = block.var(in_var_name) + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block --". + format(e)) + in_var = global_block.var(in_var_name) + if in_var is not None: + _logger.debug( + "-- var {} is got in the global block --". + format(in_var_name)) + + if in_var is None or in_var.type not in _valid_types: + continue + + if in_var.dtype == core.VarDesc.VarType.FP32: + if in_var.is_data: + to_bf16_pre_cast_ops.add(op) + else: + in_var.desc.set_dtype(core.VarDesc.VarType.BF16) + to_bf16_var_names.add(in_var_name) + + _logger.debug( + "-- op type: {}, in var name: {}, in var dtype: {} --". + format(op.type, in_var_name, in_var.dtype)) + + for out_name in op.output_names: + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and out_name != 'Y': + continue + for out_var_name in op.output(out_name): + out_var = None + try: + out_var = block.var(out_var_name) + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block --". + format(e)) + out_var = global_block.var(out_var_name) + if out_var is not None: + _logger.debug( + "-- var {} is got in the global block --". + format(out_var_name)) + + if out_var is None or out_var.type not in _valid_types: + continue + + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.BF16) + + _logger.debug( + "-- op type: {}, out var name: {}, out var dtype: {} --". 
+ format(op.type, out_var_name, out_var.dtype)) + if op.has_attr('in_dtype') and op.attr( + 'in_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('in_dtype', core.VarDesc.VarType.BF16) + if op.has_attr('out_dtype') and op.attr( + 'out_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('out_dtype', core.VarDesc.VarType.BF16) + if op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.BF16) + if op.has_attr('op_namescope') and \ + (_bf16_guard_pattern in op.attr("op_namescope")): + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + op._set_attr('mkldnn_data_type', 'bfloat16') + + # process ops in keep_fp32_ops + op_var_rename_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + num_cast_ops = 0 + if op not in keep_fp32_ops: + if op in to_bf16_pre_cast_ops: + in_var_cast_num = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16) + num_cast_ops += in_var_cast_num + else: + pre_cast_num = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32) + num_cast_ops += pre_cast_num + for out_var_name in op.output_arg_names: + out_var = block.vars.get(out_var_name) + if out_var is None or out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.BF16: + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + post_ops = find_true_post_op(ops, op, out_var_name) + for post_op in post_ops: + if post_op in keep_fp32_ops: + continue + post_cast_num = _insert_cast_post_op( + block, op, idx + pre_cast_num + 1, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16, out_var_name, + op_var_rename_map) + num_cast_ops += post_cast_num + idx += num_cast_ops + 1 + + _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) + return to_bf16_var_names + + +def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None): + """ + Traverse all parameters in the whole model and set them to the BF16 data type. + Whereas, this function will keep parameters of batchnorms in FP32. + Args: + place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors. + program (Program): The used program. + scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. + Default is None. + to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names` + will be set to BF16. Usually, it is the returned + value of `cast_model_to_bf16` API. + """ + all_parameters = [] + for block in program.blocks: + all_parameters.extend(block.all_parameters()) + + bf16_var_names = to_bf16_var_names if to_bf16_var_names else set() + var_scope = scope if scope else global_scope() + for param in all_parameters: + if param.name in bf16_var_names: + _logger.debug("---- cast {} to bf16 dtype ----".format(param.name)) + param_t = var_scope.find_var(param.name).get_tensor() + data = np.array(param_t) + param_t.set(np.uint16(data), place) + + +def rewrite_program(main_prog, amp_lists): + """ + Traverse all ops in current block and insert cast op according to + which set current op belongs to. + + 1. When an op belongs to the black list, add it to black set + 2. When an op belongs to the white list, add it to white set + 3. When an op belongs to the gray list. If one + of its inputs is the output of black set op or black list op, + add it to black set. 
If all of its previous ops are not black + op and one of its inputs is the output of white set op or + white list op, add it to white set. + 4. When an op isn't in the lists, add it to black op set. + 5. Add necessary cast ops to make sure that black set op will be + computed in fp32 mode, while white set op will be computed in + bf16 mode. + + Args: + main_prog (Program): The main program for training. + """ + block = main_prog.global_block() + ops = block.ops + white_op_set = set() + black_op_set = set() + for op in ops: + + # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, + # we don't need to handle reader op and the input of 'create_py_reader' is not + # in block, which may result in errors. + # See GeneratorLoader._init_non_iterable() for details. + if op.type == 'create_py_reader' or op.type == 'read': + continue + + if amp_lists.black_varnames is not None and _is_in_black_varnames( + op, amp_lists): + black_op_set.add(op) + continue + + if op.type in amp_lists.black_list: + black_op_set.add(op) + elif op.type in amp_lists.white_list: + white_op_set.add(op) + elif op.type in amp_lists.gray_list: + is_black_op = False + is_white_op = False + for in_name in op.input_names: + # if this op has inputs + if in_name: + for in_var_name in op.input(in_name): + in_var = block.var(in_var_name) + # this in_var isn't the output of other op + if in_var.op is None: + continue + elif in_var.op is op: + prev_op = find_true_prev_op(ops, op, in_var_name) + if prev_op is None: + continue + else: + prev_op = in_var.op + # if it's one of inputs + if prev_op in black_op_set or \ + prev_op.type in amp_lists.black_list: + is_black_op = True + elif prev_op in white_op_set or \ + prev_op.type in amp_lists.white_list: + is_white_op = True + if is_black_op: + black_op_set.add(op) + elif is_white_op: + white_op_set.add(op) + else: + pass + else: + # For numerical safe, we apply fp32 computation on ops that + # are not determined which list they should stay. + black_op_set.add(op) + + idx = 0 + while idx < len(ops): + op = ops[idx] + num_cast_ops = 0 + if op in black_op_set: + num_cast_ops = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32) + elif op in white_op_set: + num_cast_ops = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16) + else: + pass + + idx += num_cast_ops + 1 + + +def update_role_var_grad(main_prog, params_grads): + """ + Update op_role_var attr for some ops to make sure the gradients + transferred across GPUs is BF16. + 1. Check whether the op that outputs gradient is cast or not. + 2. If op is cast and gradient is FP32, remove the op_role_var + and find the prev op which outputs BF16 gradient + 3. Update the op_role_var of the prev op. + + Args: + main_prog (Program): The main program for training. + params_grads (list): A list of params and grads. 
+ """ + block = main_prog.global_block() + BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward + OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize + for p, g in params_grads: + op = g.op + if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': + role = op.attr('op_role') + if role & int(BACKWARD) and op.has_attr('op_role_var'): + op.desc.remove_attr("op_role_var") + else: + raise ValueError("The cast op {0} must be in BACKWARD role " + "and have op_role_var attr.".format(op)) + + bf16_grad_name = op.input(op.input_names[0])[0] + op_for_bf16_grad = find_true_prev_op(block.ops, op, bf16_grad_name) + op_role_var_attr_name = \ + core.op_proto_and_checker_maker.kOpRoleVarAttrName() + attr_val = [p.name, bf16_grad_name] + if op_for_bf16_grad.has_attr(op_role_var_attr_name): + attr_val.extend(op_for_bf16_grad.attr(op_role_var_attr_name)) + op_for_bf16_grad._set_attr(op_role_var_attr_name, attr_val) + + # Maximize the all_reduce overlap, and perform the cast + # operation after gradients transfer. + op._set_attr('op_role', OPTIMIZE) + # optimize op should stay behind forward and backward ops + if op == block.ops[-1]: + continue + post_ops = find_true_post_op(block.ops, op, g.name) + if post_ops: + raise ValueError("The cast op {0}'s output should not be" + "used by a non-optimize op, however, it" + "is used by {1}".format(op, post_ops[0])) + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(op.desc) + + op_idx = find_op_index(block.desc, op.desc) + if op_idx == -1: + raise ValueError("The op {0} is not in program".format(op)) + block.desc._remove_op(op_idx, op_idx + 1) + block._sync_with_cpp() diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index d37e90b4695d0..b2b24f5a83599 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -24,7 +24,7 @@ from .fp16_utils import cast_model_to_fp16 from .fp16_utils import cast_parameters_to_fp16 from .fp16_utils import update_role_var_grad -from .fp16_lists import AutoMixedPrecisionLists +from .fp16_lists import AutoMixedPrecisionListsFP16 from .amp_nn import check_finite_and_unscale from .amp_nn import update_loss_scaling import types @@ -513,7 +513,7 @@ def run_example_code(): run_example_code() """ if amp_lists is None: - amp_lists = AutoMixedPrecisionLists() + amp_lists = AutoMixedPrecisionListsFP16() if use_fp16_guard is None: use_fp16_guard = use_pure_fp16 diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py b/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py new file mode 100644 index 0000000000000..9946039ca47b4 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py @@ -0,0 +1,526 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ... import core +from ... import default_main_program +from ... 
import default_startup_program +from ... import framework +from ... import layers +from ... import program_guard +from ... import unique_name +from . import bf16_utils +from .bf16_utils import rewrite_program +from .bf16_utils import cast_model_to_bf16 +from .bf16_utils import cast_parameters_to_bf16 +from .bf16_utils import update_role_var_grad +from .bf16_lists import AutoMixedPrecisionListsBF16 +from .amp_nn import check_finite_and_unscale +from .amp_nn import update_loss_scaling +import types +import warnings + +__all__ = ["decorate_bf16"] + + +class OptimizerWithMixedPrecision(object): + """ + Optimizer with mixed-precision (MP) training. This is a wrapper of a common + optimizer, plus the support of mixed-precision pre-training. The object + of this class almost has the same behavior as the common optimizer, with the + methods `minimize()`, `backward()`, `apply_gradients()` implemented. + Additionally, it enables the MP training automatically, i.e, the creation + and maintenance of master parameters, scaling of loss, etc. + + Args: + optimizer (Optimizer): A common Optimizer object. + amp_lists (CustomOpLists): An CustomOpLists object. + init_loss_scaling (float): The initial loss scaling factor. + use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling. + incr_every_n_steps(int): Increases loss scaling every n consecutive + steps with finite gradients. + decr_every_n_nan_or_inf(int): Decreases loss scaling every n + accumulated steps with nan or + inf gradients. + incr_ratio(float): The multiplier to use when increasing the loss + scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing + the loss scaling. + use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. + use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. + Default None, which means that its value is equal to `use_pure_bf16`. + + """ + + def __init__(self, optimizer, amp_lists, init_loss_scaling, + use_dynamic_loss_scaling, incr_every_n_steps, + decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_bf16, + use_bf16_guard): + self._optimizer = optimizer + self._amp_lists = amp_lists + self._param_grads = None + self._train_program = None + + self._is_distributed = False + self._scaled_loss = None + self._loss_scaling = None + self._init_loss_scaling = init_loss_scaling + self._use_dynamic_loss_scaling = use_dynamic_loss_scaling + self._learning_rate = optimizer._learning_rate + self._learning_rate_map = optimizer._learning_rate_map + self._use_pure_bf16 = use_pure_bf16 + self._use_bf16_guard = use_bf16_guard + self._to_bf16_var_names = None + if self._use_dynamic_loss_scaling: + self._incr_every_n_steps = incr_every_n_steps + self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf + self._incr_ratio = incr_ratio + self._decr_ratio = decr_ratio + self._num_good_steps = None + self._num_bad_steps = None + + def _set_distributed(self, flag): + # if distributed, all cards will communication with each other, + # overlap communication and computation by split the + # check_finite_and_unscale op. + self._is_distributed = flag + + def get_loss_scaling(self): + """Return the real-time loss scaling factor. + """ + return self._loss_scaling + + def get_scaled_loss(self): + """Return the scaled loss. + It's useful when you feed customed loss into executor. 
+ """ + return self._scaled_loss + + def _init_amp_var(self): + self._loss_scaling = layers.create_global_var( + name=unique_name.generate("loss_scaling"), + shape=[1], + value=self._init_loss_scaling, + dtype='float32', + persistable=True) + + if self._use_dynamic_loss_scaling: + self._num_good_steps = layers.create_global_var( + name=unique_name.generate("num_good_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + self._num_bad_steps = layers.create_global_var( + name=unique_name.generate("num_bad_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + + # Ensure the data type of learning rate vars is float32 (same as the + # master parameter dtype) + if isinstance(self._optimizer._learning_rate, float): + self._optimizer._learning_rate_map[default_main_program()] = \ + layers.create_global_var( + name=unique_name.generate("learning_rate"), + shape=[1], + value=float(self._optimizer._learning_rate), + dtype='float32', + persistable=True) + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + Backward propagation or auto differentiation for gradients' computation. + + Args: + loss (Variable): The loss Variable to minimize. + startup_program (Program|None): The startup Program for initializing + parameters in `parameter_list`. + parameter_list (list|None): A list of Variables to update. + no_grad_set (set|None): A set of Variables should be ignored. + callbacks (list|None): A list of callable objects to run when appending + backward operator for one parameter. + + Returns: + A list of (param, grad), which is a tuple of a parameter and its + gradient respectively, and the scaled loss. + """ + train_program = loss.block.program + self._train_program = train_program + + with program_guard(self._train_program, startup_program): + self._init_amp_var() + + if self._use_pure_bf16: + self._to_bf16_var_names = cast_model_to_bf16( + self._train_program, self._amp_lists, self._use_bf16_guard) + else: + rewrite_program(self._train_program, self._amp_lists) + + if loss.dtype != core.VarDesc.VarType.FP32: + loss = loss.astype('float32') + # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, + # the model can be optimized. + if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0: + self._scaled_loss = loss * self._loss_scaling + else: + self._scaled_loss = loss + + params_grads = self._optimizer.backward( + self._scaled_loss, startup_program, parameter_list, no_grad_set, + callbacks) + return params_grads + + def amp_init(self, + place, + scope=None, + test_program=None, + use_bf16_test=False): + """ + Init the amp training, such as cast fp32 parameters to bf16 type. + + Args: + place(CUDAPlace): place is used to initialize + bf16 parameters with fp32 values. + scope(Scope): The scope is used to find fp32 parameters. + test_program(Program): The program is used for testing. + use_bf16_test(bool): Whether to use bf16 testing. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn.functional as F + paddle.enable_static() + + def run_example_code(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') + conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) + # 1) Use bf16_guard to control the range of bf16 kernels used. 
+ with paddle.static.amp.bf16_guard(): + bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") + pool = F.max_pool2d(bn, kernel_size=2, stride=2) + hidden = paddle.static.nn.fc(pool, size=10) + loss = paddle.mean(hidden) + # 2) Create the optimizer and set `multi_precision` to True. + # Setting `multi_precision` to True can avoid the poor accuracy + # or the slow convergence in a way. + optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) + # 3) These ops in `custom_black_list` will keep in the float32 computation type. + amp_list = paddle.static.amp.CustomOpLists( + custom_black_list=['pool2d']) + # 4) The entry of Paddle AMP. + # Enable pure bf16 training by setting `use_pure_bf16` to True. + optimizer = paddle.static.amp.decorate( + optimizer, + amp_list, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_bf16=True) + # If you don't use the default_startup_program(), you sholud pass + # your defined `startup_program` into `minimize`. + optimizer.minimize(loss) + exe.run(paddle.static.default_startup_program()) + # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). + # If you want to perform the testing process, you should pass `test_program` into `amp_init`. + optimizer.amp_init(place, scope=paddle.static.global_scope()) + + if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: + run_example_code() + """ + assert self._train_program is not None, \ + "Please call the minimize method first." + if self._use_pure_bf16: + cast_parameters_to_bf16(place, self._train_program, scope, + self._to_bf16_var_names) + if test_program is not None: + if self._use_pure_bf16: + cast_model_to_bf16(test_program, self._amp_lists, + self._use_bf16_guard) + elif use_bf16_test: + rewrite_program(test_program, self._amp_lists) + + def apply_gradients(self, params_grads): + """ + Check scaled gradients to determine whether to update loss scaling and update + parameters by their scaled gradients. + + Args: + params_grads (list): A list of params and scaled grads. + + Returns: + A list of optimize operators. + """ + + # Change the op_role_var attr for some ops, so that gradients + # transferred across GPUs can be BF16. + update_role_var_grad(self._train_program, params_grads) + + # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, + # the model can be optimized. + if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0: + return self._optimizer.apply_gradients(params_grads) + + grads = [g for _, g in params_grads] + fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] + bf16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.BF16] + assert len(fp32_grads) + len(bf16_grads) == len(grads), \ + "Data types of all grads must be either bf16 or fp32." 
+ + found_infs = [] + if self._is_distributed: + # if distributed, split check_finite_and_unscale to overlap + # unscale with communication + for p, g in params_grads: + with self._train_program._optimized_guard([p, g]): + _, found_inf = check_finite_and_unscale( + [g, ], self._loss_scaling, name="find_infinite_scale") + found_infs.append(found_inf) + elif self._use_pure_bf16: + if fp32_grads: + with self._train_program._optimized_guard(fp32_grads): + _, fp32_found_inf = check_finite_and_unscale( + fp32_grads, + self._loss_scaling, + name="find_infinite_scale_fp32") + found_infs.append(fp32_found_inf) + if bf16_grads: + with self._train_program._optimized_guard(bf16_grads): + _, bf16_found_inf = check_finite_and_unscale( + bf16_grads, + self._loss_scaling, + name="find_infinite_scale_bf16") + found_infs.append(bf16_found_inf) + else: + with self._train_program._optimized_guard(grads): + _, found_inf = check_finite_and_unscale( + grads, self._loss_scaling, name="find_infinite_scale") + + if self._use_dynamic_loss_scaling: + if self._is_distributed or self._use_pure_bf16: + with self._train_program._optimized_guard([]): + all_infs = layers.concat(found_infs) + found_inf = layers.reduce_any(all_infs) + + if self._use_pure_bf16: + stop_update = False + with self._train_program._optimized_guard([]): + if fp32_grads: + update_loss_scaling( + fp32_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp32") + stop_update = True + if bf16_grads: + update_loss_scaling( + bf16_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_bf16") + else: + with self._train_program._optimized_guard([]): + update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") + + optimize_ops = self._optimizer.apply_gradients(params_grads) + return optimize_ops + + def apply_optimize(self, loss, startup_program, params_grads): + program = loss.block.program + with program_guard(program, startup_program): + optimize_ops = self.apply_gradients(params_grads) + return optimize_ops + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + """ + Perform optimization by minimizing the given loss. + + Args: + loss (Variable): The loss Variable. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + parameter_list (list): list of Variables to update. + no_grad_set (set|None): set of Variables should be ignored. + + Returns: + The scaled loss by scaling factor, the list of optimize ops, and a + list of scaled parameters and gradients. + """ + opt_dict = self._optimizer.__class__.__dict__ + if 'minimize' in opt_dict and isinstance(opt_dict['minimize'], + types.FunctionType): + warnings.warn( + "The decorated optimizer has its own `minimize` method, but it will not be executed." 
+ ) + + scaled_params_grads = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + + optimize_ops = self.apply_optimize(loss, startup_program, + scaled_params_grads) + + return optimize_ops, scaled_params_grads + + +def decorate_bf16(optimizer, + amp_lists=None, + init_loss_scaling=2**15, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=2, + incr_ratio=2.0, + decr_ratio=0.8, + use_dynamic_loss_scaling=True, + use_pure_bf16=False, + use_bf16_guard=None): + """ + Decorate the given optimizer to adapt to the mixed-precision training. + + Args: + optimizer(Optimizer): A common Optimizer. + amp_lists (CustomOpLists): An CustomOpLists object. + init_loss_scaling(float): The initial loss scaling factor. + incr_every_n_steps(int): Increases loss scaling every n consecutive + steps with finite gradients. + decr_every_n_nan_or_inf(int): Decreases loss scaling every n + accumulated steps with nan or + inf gradients. + incr_ratio(float): The multiplier to use when increasing the loss + scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing + the loss scaling. + use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. + use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. + use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. + Default None, which means that its value equals to `use_pure_bf16`. + + Returns: + An optimizer acting like a normal one but with mixed-precision training + enabled. + + Examples 1: + .. code-block:: python + + # black&white list based strategy example + import paddle + import paddle.static as static + + paddle.enable_static() + + data = static.data(name='X', shape=[None, 1], dtype='float32') + hidden = static.nn.fc(x=data, size=10) + loss = paddle.mean(hidden) + optimizer = paddle.optimizer.Adam(learning_rate=0.001) + + mp_optimizer = static.amp.decorate( + optimizer=optimizer, init_loss_scaling=8.0) + + ops, param_grads = mp_optimizer.minimize(loss) + scaled_loss = mp_optimizer.get_scaled_loss() + + Examples 2: + .. code-block:: python + + # pure bf16 training example + import numpy as np + import paddle + import paddle.nn.functional as F + + def run_example_code(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') + conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) + # 1) Use bf16_guard to control the range of bf16 kernels used. + with paddle.static.amp.bf16_guard(): + bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") + pool = F.max_pool2d(bn, kernel_size=2, stride=2) + hidden = paddle.static.nn.fc(pool, size=10) + loss = paddle.mean(hidden) + # 2) Create the optimizer and set `multi_precision` to True. + # Setting `multi_precision` to True can avoid the poor accuracy + # or the slow convergence in a way. + optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) + # 3) These ops in `custom_black_list` will keep in the float32 computation type. + amp_list = paddle.static.amp.CustomOpLists( + custom_black_list=['pool2d']) + # 4) The entry of Paddle AMP. + # Enable pure bf16 training by setting `use_pure_bf16` to True. 
+ optimizer = paddle.static.amp.decorate( + optimizer, + amp_list, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_bf16=True) + # If you don't use the default_startup_program(), you sholud pass + # your defined `startup_program` into `minimize`. + optimizer.minimize(loss) + exe.run(paddle.static.default_startup_program()) + # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). + # If you want to perform the testing process, you should pass `test_program` into `amp_init`. + optimizer.amp_init(place, scope=paddle.static.global_scope()) + + if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: + run_example_code() + """ + if amp_lists is None: + amp_lists = AutoMixedPrecisionListsBF16() + + if use_bf16_guard is None: + use_bf16_guard = use_pure_bf16 + + mp_optimizer = OptimizerWithMixedPrecision( + optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + use_pure_bf16, use_bf16_guard) + + return mp_optimizer diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index c88ae2d9cbf60..642a8b7b79848 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -14,7 +14,7 @@ import copy -__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"] +__all__ = ["AutoMixedPrecisionListsFP16"] class AutoMixedPrecisionLists(object): @@ -25,21 +25,29 @@ class AutoMixedPrecisionLists(object): execution mode (fp32 or fp16). Args: + white_list_ (set): mode default white list. + black_list_ (set): mode default black list. + gray_list_ (set): mode default gray list. + unsupported_list_ (set): mode default unsupported list. custom_white_list (set): Users' custom white list. custom_black_list (set): Users' custom black list. custom_black_varnames (set): Users' custom black varibles' names. """ def __init__(self, + white_list_, + black_list_, + gray_list_, + unsupported_list_, custom_white_list=None, custom_black_list=None, custom_black_varnames=None): self._custom_white_list = custom_white_list self._custom_black_list = custom_black_list - self.white_list = copy.copy(white_list) - self.black_list = copy.copy(black_list) - self.gray_list = copy.copy(gray_list) - self.unsupported_list = copy.copy(unsupported_fp16_list) + self.white_list = copy.copy(white_list_) + self.black_list = copy.copy(black_list_) + self.gray_list = copy.copy(gray_list_) + self.unsupported_list = copy.copy(unsupported_list_) self.black_varnames = copy.copy(custom_black_varnames) self._update_list() @@ -69,7 +77,22 @@ def _update_list(self): self.unsupported_list.add(op_name) -# The three sets listed below are changed dynamiclly. They don't contain all +class AutoMixedPrecisionListsFP16(AutoMixedPrecisionLists): + def __init__(self, + custom_white_list=None, + custom_black_list=None, + custom_black_varnames=None): + super(AutoMixedPrecisionListsFP16, self).__init__( + white_list, + black_list, + gray_list, + unsupported_fp16_list, + custom_white_list=custom_white_list, + custom_black_list=custom_black_list, + custom_black_varnames=custom_black_varnames) + + +# The three sets listed below are changed dynamiclly. They don't contain all # paddle ops currently. 
# The set of ops that support fp16 calculation and are considered numerically- @@ -290,4 +313,4 @@ def _update_list(self): 'lookup_table_v2', } -CustomOpLists = AutoMixedPrecisionLists +CustomOpListsFP16 = AutoMixedPrecisionListsFP16 diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index f9c3a613c4053..474c4d36281d5 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -20,7 +20,7 @@ from ... import global_scope from ...log_helper import get_logger from ...wrapped_decorator import signature_safe_contextmanager -from .fp16_lists import AutoMixedPrecisionLists +from .fp16_lists import AutoMixedPrecisionListsFP16 import collections import logging import numpy as np @@ -317,7 +317,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): """ if amp_lists is None: - amp_lists = AutoMixedPrecisionLists() + amp_lists = AutoMixedPrecisionListsFP16() global_block = program.global_block() keep_fp32_ops = set() to_fp16_var_names = set() diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index 0280dfcf67b1d..3255364911562 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -137,7 +137,7 @@ def train(net_type, use_cuda, save_dirname, is_local): optimizer = fluid.optimizer.Lamb(learning_rate=0.001) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( custom_black_varnames={"loss", "conv2d_0.w_0"}) mp_optimizer = decorate( optimizer=optimizer, @@ -282,7 +282,7 @@ def test_amp_lists(self): gray_list = copy.copy( fluid.contrib.mixed_precision.fp16_lists.gray_list) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists() + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16() self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) self.assertEqual(amp_lists.gray_list, gray_list) @@ -299,7 +299,7 @@ def test_amp_lists_1(self): white_list.add('exp') black_list.remove('exp') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( {'exp'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -317,7 +317,7 @@ def test_amp_lists_2(self): white_list.add('tanh') gray_list.remove('tanh') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( {'tanh'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -334,7 +334,7 @@ def test_amp_lists_3(self): # 3. 
w={'lstm'}, b=None white_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( {'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -352,7 +352,7 @@ def test_amp_lists_4(self): white_list.remove('conv2d') black_list.add('conv2d') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( custom_black_list={'conv2d'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -370,7 +370,7 @@ def test_amp_lists_5(self): black_list.add('tanh') gray_list.remove('tanh') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( custom_black_list={'tanh'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -387,7 +387,7 @@ def test_amp_lists_6(self): # 6. w=None, b={'lstm'} black_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( custom_black_list={'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -396,9 +396,10 @@ def test_amp_lists_6(self): def test_amp_lists_7(self): # 7. w={'lstm'} b={'lstm'} # raise ValueError - self.assertRaises(ValueError, - fluid.contrib.mixed_precision.AutoMixedPrecisionLists, - {'lstm'}, {'lstm'}) + self.assertRaises( + ValueError, + fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16, + {'lstm'}, {'lstm'}) def test_vgg_cuda(self): with self.scope_prog_guard(): @@ -441,7 +442,7 @@ def decorate_with_data_loader(self): avg_cost = fluid.layers.mean(cost) optimizer = fluid.optimizer.Lamb(learning_rate=0.001) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( custom_black_varnames={"loss", "conv2d_0.w_0"}) mp_optimizer = decorate( optimizer=optimizer, diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py new file mode 100644 index 0000000000000..ba176d08f8341 --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import contextlib +import unittest +import numpy as np +import paddle.fluid.layers as layers +from paddle.fluid import core +from paddle.fluid.contrib.mixed_precision.bf16_utils import cast_model_to_bf16 +from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16, convert_uint16_to_float + +paddle.enable_static() + + +class TestImageMultiPrecision(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.seed = 111 + + @classmethod + def tearDownClass(cls): + pass + + @contextlib.contextmanager + def static_graph(self): + with self.scope_prog_guard(): + paddle.seed(self.seed) + paddle.framework.random._manual_program_seed(self.seed) + yield + + @contextlib.contextmanager + def scope_prog_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + def get_static_graph_result(self, feed, fetch_list, with_lod=False): + exe = fluid.Executor(core.CPUPlace()) + exe.run(fluid.default_startup_program()) + prog = fluid.default_main_program() + cast_model_to_bf16(prog, use_bf16_guard=True) + return exe.run(prog, + feed=feed, + fetch_list=fetch_list, + return_numpy=(not with_lod)) + + def test_elementwise_math(self): + size = 3 + n = np.ones([size, size], dtype='float32') * 3.2 + nn = np.ones([size, size], dtype='float32') * -2.7 + + n_bf16 = convert_float_to_uint16(n) + nn_bf16 = convert_float_to_uint16(nn) + + with self.static_graph(): + t_bf16 = layers.data( + name='t_bf16', shape=[size, size], dtype=np.uint16) + tt_bf16 = layers.data( + name='tt_bf16', shape=[size, size], dtype=np.uint16) + t = layers.data(name='t', shape=[size, size], dtype='float32') + tt = layers.data(name='tt', shape=[size, size], dtype='float32') + + ret = layers.elementwise_add(t, tt) + ret = layers.elementwise_mul(ret, t) + with paddle.static.amp.bf16_guard(): + ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16) + ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16) + + static_ret_bf16, static_ret = self.get_static_graph_result( + feed={ + 't': n, + 'tt': nn, + 't_bf16': n_bf16, + 'tt_bf16': nn_bf16, + }, + fetch_list=[ret_bf16, ret]) + + stt = convert_uint16_to_float(static_ret_bf16) + self.assertTrue(np.allclose(stt, static_ret, 1e-2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index b2db00296bf95..f693e250a4a14 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -29,6 +29,7 @@ _PADDLE_DTYPE_2_NUMPY_DTYPE = { core.VarDesc.VarType.BOOL: 'bool', core.VarDesc.VarType.FP16: 'float16', + core.VarDesc.VarType.BF16: 'uint16', core.VarDesc.VarType.FP32: 'float32', core.VarDesc.VarType.FP64: 'float64', core.VarDesc.VarType.INT8: 'int8', @@ -47,16 +48,18 @@ def convert_dtype(dtype): return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] elif isinstance(dtype, type): if dtype in [ - np.bool, np.float16, np.float32, np.float64, np.int8, np.int16, - np.int32, np.int64, np.uint8, np.complex64, np.complex128 + np.bool, np.float16, np.uint16, np.float32, np.float64, np.int8, + np.int16, np.int32, np.int64, np.uint8, np.complex64, + np.complex128 ]: return dtype.__name__ else: if dtype in [ - 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', - 'int32', 'int64', 'uint8', 'complex64', 'complex128', u'bool', - u'float16', u'float32', u'float64', u'int8', u'int16', u'int32', - 
u'int64', u'uint8', u'complex64', u'complex128' + 'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', + 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', + u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8', + u'int16', u'int32', u'int64', u'uint8', u'complex64', + u'complex128' ]: # this code is a little bit dangerous, since error could happen # when casting no-ascii code to str in python2. @@ -66,7 +69,7 @@ def convert_dtype(dtype): return str(dtype) raise TypeError( - "dtype must be any of [bool, float16, float32, float64, int8, int16, " + "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, " "int32, int64, uint8, complex64, complex128], but received %s" % dtype) @@ -123,6 +126,10 @@ def check_dtype(input_dtype, warnings.warn( "The data type of '%s' in %s only support float16 in GPU now. %s" % (input_name, op_name, extra_message)) + if convert_dtype(input_dtype) in ['uint16']: + warnings.warn( + "The data type of '%s' in %s only support bfloat16 in OneDNN now. %s" + % (input_name, op_name, extra_message)) if convert_dtype(input_dtype) not in expected_dtype: raise TypeError( "The data type of '%s' in %s must be %s, but received %s. %s" % diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index fa8df14c8669b..d6f79186c8509 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11354,9 +11354,11 @@ def _elementwise_op(helper): assert x is not None, 'x cannot be None in {}'.format(op_type) assert y is not None, 'y cannot be None in {}'.format(op_type) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) + x, 'x', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + op_type) check_variable_and_dtype( - y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) + y, 'y', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + op_type) axis = helper.kwargs.get('axis', -1) use_mkldnn = helper.kwargs.get('use_mkldnn', False) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 939e2ac0f59fd..dff96a8cbc3c4 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -244,17 +244,12 @@ def convert_float_to_uint16(float_list, data_format="NCHW"): return new_output -def copy_bits_from_uint16_to_float(i): - i = np.uint32(i) << 16 - return struct.unpack(' Date: Wed, 24 Feb 2021 20:27:12 +0100 Subject: [PATCH 02/33] Updates for CI --- .../fluid/contrib/mixed_precision/__init__.py | 2 +- .../contrib/mixed_precision/bf16_lists.py | 18 +++++++------ .../contrib/mixed_precision/decorator.py | 4 +-- .../contrib/mixed_precision/fp16_lists.py | 10 ++++---- .../contrib/mixed_precision/fp16_utils.py | 4 +-- .../tests/test_image_classification_fp16.py | 25 +++++++++---------- 6 files changed, 32 insertions(+), 31 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index baab9e167c069..5b7efda36a74c 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -28,7 +28,7 @@ from .bf16_utils import * __all__ = decorator.__all__ -__all__ = decorator_bf16.__all__ +__all__ += decorator_bf16.__all__ __all__ += fp16_lists.__all__ __all__ += bf16_lists.__all__ __all__ += fp16_utils.__all__ diff --git 
a/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py index b41cb10622785..0aa169724fc80 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_lists.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fp16_lists import AutoMixedPrecisionLists, \ +from .fp16_lists import AutoMixedPrecisionListsCommon, \ white_list as white_list_fp16, black_list as black_list_fp16, \ gray_list as gray_list_fp16, unsupported_fp16_list __all__ = ["AutoMixedPrecisionListsBF16"] -class AutoMixedPrecisionListsBF16(AutoMixedPrecisionLists): +class AutoMixedPrecisionListsBF16(AutoMixedPrecisionListsCommon): def __init__(self, custom_white_list=None, custom_black_list=None, @@ -34,11 +34,13 @@ def __init__(self, custom_black_varnames=custom_black_varnames) -white_list = {'elementwise_add'} +white_list = {'elementwise_add', } + black_list = black_list_fp16.copy().copy() -black_list.update(white_list_fp16) -black_list.update(gray_list_fp16) -gray_list = set() -unsupported_list = unsupported_fp16_list +black_list |= white_list_fp16 +black_list |= gray_list_fp16 +black_list -= white_list -CustomOpListsBF16 = AutoMixedPrecisionListsBF16 +gray_list = set() +unsupported_list = unsupported_fp16_list.copy().copy() +unsupported_list -= white_list diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index b2b24f5a83599..d37e90b4695d0 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -24,7 +24,7 @@ from .fp16_utils import cast_model_to_fp16 from .fp16_utils import cast_parameters_to_fp16 from .fp16_utils import update_role_var_grad -from .fp16_lists import AutoMixedPrecisionListsFP16 +from .fp16_lists import AutoMixedPrecisionLists from .amp_nn import check_finite_and_unscale from .amp_nn import update_loss_scaling import types @@ -513,7 +513,7 @@ def run_example_code(): run_example_code() """ if amp_lists is None: - amp_lists = AutoMixedPrecisionListsFP16() + amp_lists = AutoMixedPrecisionLists() if use_fp16_guard is None: use_fp16_guard = use_pure_fp16 diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 642a8b7b79848..6a2d72a58e0fe 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -14,10 +14,10 @@ import copy -__all__ = ["AutoMixedPrecisionListsFP16"] +__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"] -class AutoMixedPrecisionLists(object): +class AutoMixedPrecisionListsCommon(object): """ AutoMixedPrecisionLists is a class for black/white list. 
It can update pre-defined black list and white list according to users' custom black @@ -77,12 +77,12 @@ def _update_list(self): self.unsupported_list.add(op_name) -class AutoMixedPrecisionListsFP16(AutoMixedPrecisionLists): +class AutoMixedPrecisionLists(AutoMixedPrecisionListsCommon): def __init__(self, custom_white_list=None, custom_black_list=None, custom_black_varnames=None): - super(AutoMixedPrecisionListsFP16, self).__init__( + super(AutoMixedPrecisionLists, self).__init__( white_list, black_list, gray_list, @@ -313,4 +313,4 @@ def __init__(self, 'lookup_table_v2', } -CustomOpListsFP16 = AutoMixedPrecisionListsFP16 +CustomOpLists = AutoMixedPrecisionLists diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 474c4d36281d5..f9c3a613c4053 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -20,7 +20,7 @@ from ... import global_scope from ...log_helper import get_logger from ...wrapped_decorator import signature_safe_contextmanager -from .fp16_lists import AutoMixedPrecisionListsFP16 +from .fp16_lists import AutoMixedPrecisionLists import collections import logging import numpy as np @@ -317,7 +317,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): """ if amp_lists is None: - amp_lists = AutoMixedPrecisionListsFP16() + amp_lists = AutoMixedPrecisionLists() global_block = program.global_block() keep_fp32_ops = set() to_fp16_var_names = set() diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index 3255364911562..0280dfcf67b1d 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -137,7 +137,7 @@ def train(net_type, use_cuda, save_dirname, is_local): optimizer = fluid.optimizer.Lamb(learning_rate=0.001) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_varnames={"loss", "conv2d_0.w_0"}) mp_optimizer = decorate( optimizer=optimizer, @@ -282,7 +282,7 @@ def test_amp_lists(self): gray_list = copy.copy( fluid.contrib.mixed_precision.fp16_lists.gray_list) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16() + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists() self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) self.assertEqual(amp_lists.gray_list, gray_list) @@ -299,7 +299,7 @@ def test_amp_lists_1(self): white_list.add('exp') black_list.remove('exp') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( {'exp'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -317,7 +317,7 @@ def test_amp_lists_2(self): white_list.add('tanh') gray_list.remove('tanh') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( {'tanh'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -334,7 +334,7 @@ def test_amp_lists_3(self): # 3. 
w={'lstm'}, b=None white_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( {'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -352,7 +352,7 @@ def test_amp_lists_4(self): white_list.remove('conv2d') black_list.add('conv2d') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_list={'conv2d'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -370,7 +370,7 @@ def test_amp_lists_5(self): black_list.add('tanh') gray_list.remove('tanh') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_list={'tanh'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -387,7 +387,7 @@ def test_amp_lists_6(self): # 6. w=None, b={'lstm'} black_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_list={'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -396,10 +396,9 @@ def test_amp_lists_6(self): def test_amp_lists_7(self): # 7. w={'lstm'} b={'lstm'} # raise ValueError - self.assertRaises( - ValueError, - fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16, - {'lstm'}, {'lstm'}) + self.assertRaises(ValueError, + fluid.contrib.mixed_precision.AutoMixedPrecisionLists, + {'lstm'}, {'lstm'}) def test_vgg_cuda(self): with self.scope_prog_guard(): @@ -442,7 +441,7 @@ def decorate_with_data_loader(self): avg_cost = fluid.layers.mean(cost) optimizer = fluid.optimizer.Lamb(learning_rate=0.001) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsFP16( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_varnames={"loss", "conv2d_0.w_0"}) mp_optimizer = decorate( optimizer=optimizer, From 9fc6899de1e66dc31a6363a579a051300e206155 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Thu, 25 Feb 2021 12:58:39 +0100 Subject: [PATCH 03/33] More updates for CI --- .../contrib/mixed_precision/bf16_utils.py | 22 ++++++++++++++++++- .../contrib/tests/test_model_cast_to_bf16.py | 4 ++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index e8e8372b066f2..4ba7eea541a38 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function +import struct from ... import core from ... 
import framework @@ -25,7 +26,10 @@ import logging import numpy as np -__all__ = ["bf16_guard", "cast_model_to_bf16", "cast_parameters_to_bf16"] +__all__ = [ + "bf16_guard", "cast_model_to_bf16", "cast_parameters_to_bf16", + "convert_float_to_uint16", "convert_uint16_to_float" +] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') @@ -38,6 +42,22 @@ _bf16_guard_pattern = "__use_bf16__" +def convert_float_to_uint16(in_list): + in_list = np.asarray(in_list) + out = np.vectorize( + lambda x: struct.unpack('> 16, + otypes=[np.uint16])(in_list.flat) + return np.reshape(out, in_list.shape) + + +def convert_uint16_to_float(in_list): + in_list = np.asarray(in_list) + out = np.vectorize( + lambda x: struct.unpack(' Date: Thu, 25 Feb 2021 15:09:43 +0100 Subject: [PATCH 04/33] More updates for CI --- python/paddle/fluid/contrib/tests/CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index a28588bfa5382..050610f738b67 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -2,6 +2,7 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) +list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) @@ -12,3 +13,8 @@ py_test_modules(test_multi_precision_fp16_train MODULES test_multi_precision_fp1 set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) + +if(WITH_MKLDNN) + py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) + set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) +endif() From d532ca08fd5a4f7a233a66dd377360066774d387 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Thu, 25 Feb 2021 16:50:50 +0100 Subject: [PATCH 05/33] Added test for bf16_utils --- .../paddle/fluid/contrib/tests/CMakeLists.txt | 6 +- .../fluid/contrib/tests/test_bf16_utils.py | 172 ++++++++++++++++++ 2 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/contrib/tests/test_bf16_utils.py diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 050610f738b67..833573f52152a 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -3,6 +3,7 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) +list(REMOVE_ITEM TEST_OPS test_bf16_utils) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) @@ -15,6 +16,7 @@ set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) if(WITH_MKLDNN) - py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) - set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) + py_test(test_bf16_utils SRCS test_bf16_utils.py) + py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) + set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py 
b/python/paddle/fluid/contrib/tests/test_bf16_utils.py new file mode 100644 index 0000000000000..11135a62c1f7a --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import unittest +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.contrib.mixed_precision import bf16_utils +import paddle + +paddle.enable_static() + + +class AMPTest(unittest.TestCase): + def test_amp_lists(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists() + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_1(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 1. w={'exp}, b=None + white_list.add('exp') + black_list.remove('exp') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + {'exp'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_2(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 2. w={'tanh'}, b=None + white_list.add('tanh') + gray_list.remove('tanh') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + {'tanh'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_3(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 3. 
w={'lstm'}, b=None + white_list.add('lstm') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + {'lstm'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_4(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 4. w=None, b={'conv2d'} + white_list.remove('conv2d') + black_list.add('conv2d') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + custom_black_list={'conv2d'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_5(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 5. w=None, b={'tanh'} + black_list.add('tanh') + gray_list.remove('tanh') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + custom_black_list={'tanh'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_6(self): + white_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.white_list) + black_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.black_list) + gray_list = copy.copy( + fluid.contrib.mixed_precision.fp16_lists.gray_list) + + # 6. w=None, b={'lstm'} + black_list.add('lstm') + + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + custom_black_list={'lstm'}) + self.assertEqual(amp_lists.white_list, white_list) + self.assertEqual(amp_lists.black_list, black_list) + self.assertEqual(amp_lists.gray_list, gray_list) + + def test_amp_lists_7(self): + # 7. 
w={'lstm'} b={'lstm'} + # raise ValueError + self.assertRaises(ValueError, + fluid.contrib.mixed_precision.AutoMixedPrecisionLists, + {'lstm'}, {'lstm'}) + + def test_find_op_index(self): + block = fluid.default_main_program().global_block() + op_desc = core.OpDesc() + idx = bf16_utils.find_op_index(block.desc, op_desc) + assert (idx == -1) + + def test_find_true_post_op(self): + block = fluid.default_main_program().global_block() + + var1 = block.create_var(name="X", shape=[3], dtype='float32') + var2 = block.create_var(name="Y", shape=[3], dtype='float32') + var3 = block.create_var(name="Z", shape=[3], dtype='float32') + op1 = block.append_op( + type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]}) + op2 = block.append_op( + type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]}) + res = bf16_utils.find_true_post_op(block.ops, op1, "Y") + assert (res == [op2]) + + +if __name__ == '__main__': + unittest.main() From deb275fdef396e53dd3536b89f36f41380dc2593 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Thu, 25 Feb 2021 19:27:43 +0100 Subject: [PATCH 06/33] Added test for bf16_utils --- .../paddle/fluid/contrib/tests/CMakeLists.txt | 2 - .../fluid/contrib/tests/test_bf16_utils.py | 81 ++++++++++--------- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 833573f52152a..779cf33b6b8b9 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -3,7 +3,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) -list(REMOVE_ITEM TEST_OPS test_bf16_utils) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) @@ -16,7 +15,6 @@ set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) if(WITH_MKLDNN) - py_test(test_bf16_utils SRCS test_bf16_utils.py) py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index 11135a62c1f7a..94afa03e97ce3 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -24,30 +24,30 @@ class AMPTest(unittest.TestCase): def test_amp_lists(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists() + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16() self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_1(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) 
gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) # 1. w={'exp}, b=None white_list.add('exp') black_list.remove('exp') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( {'exp'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -55,17 +55,17 @@ def test_amp_lists_1(self): def test_amp_lists_2(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) # 2. w={'tanh'}, b=None + black_list.remove('tanh') white_list.add('tanh') - gray_list.remove('tanh') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( {'tanh'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -73,16 +73,16 @@ def test_amp_lists_2(self): def test_amp_lists_3(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) # 3. w={'lstm'}, b=None white_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( {'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -90,52 +90,52 @@ def test_amp_lists_3(self): def test_amp_lists_4(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) - # 4. w=None, b={'conv2d'} - white_list.remove('conv2d') - black_list.add('conv2d') + # 4. w=None, b={'elementwise_add'} + white_list.remove('elementwise_add') + black_list.add('elementwise_add') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( - custom_black_list={'conv2d'}) + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( + custom_black_list={'elementwise_add'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_5(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) - # 5. 
w=None, b={'tanh'} - black_list.add('tanh') - gray_list.remove('tanh') + # 5. w=None, b={'elementwise_add'} + black_list.add('elementwise_add') + white_list.remove('elementwise_add') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( - custom_black_list={'tanh'}) + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( + custom_black_list={'elementwise_add'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_6(self): white_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.white_list) + fluid.contrib.mixed_precision.bf16_lists.white_list) black_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.black_list) + fluid.contrib.mixed_precision.bf16_lists.black_list) gray_list = copy.copy( - fluid.contrib.mixed_precision.fp16_lists.gray_list) + fluid.contrib.mixed_precision.bf16_lists.gray_list) # 6. w=None, b={'lstm'} black_list.add('lstm') - amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16( custom_black_list={'lstm'}) self.assertEqual(amp_lists.white_list, white_list) self.assertEqual(amp_lists.black_list, black_list) @@ -144,9 +144,10 @@ def test_amp_lists_6(self): def test_amp_lists_7(self): # 7. w={'lstm'} b={'lstm'} # raise ValueError - self.assertRaises(ValueError, - fluid.contrib.mixed_precision.AutoMixedPrecisionLists, - {'lstm'}, {'lstm'}) + self.assertRaises( + ValueError, + fluid.contrib.mixed_precision.AutoMixedPrecisionListsBF16, + {'lstm'}, {'lstm'}) def test_find_op_index(self): block = fluid.default_main_program().global_block() From a3fdf514b88a4280378e04251e54069b35b2f9cf Mon Sep 17 00:00:00 2001 From: arlesniak Date: Fri, 26 Feb 2021 09:51:01 +0100 Subject: [PATCH 07/33] Changes for CI --- python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py | 4 +++- tools/parallel_UT_rule.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 713b74ebfd307..7c29e2fd732c4 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -27,7 +27,9 @@ paddle.enable_static() -class TestImageMultiPrecision(unittest.TestCase): +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestModelCastBF16(unittest.TestCase): @classmethod def setUpClass(cls): cls.seed = 111 diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index a5239e534e2f5..3fb78b0d0a19a 100644 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -219,6 +219,7 @@ 'test_full_op', 'test_framework_debug_str', 'test_fp16_utils', + 'test_bf16_utils', 'test_fleet_rolemaker_4', 'test_flags_use_mkldnn', 'test_filter_by_instag_op', From 204760a36549e26c543b4fa67586a31ae1df5816 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Sun, 28 Feb 2021 20:00:50 +0100 Subject: [PATCH 08/33] Changes for CI, more tests --- .../fluid/contrib/mixed_precision/amp_nn.py | 3 - .../contrib/mixed_precision/bf16_utils.py | 84 ++---- .../contrib/mixed_precision/decorator_bf16.py | 260 ++---------------- .../contrib/tests/test_fit_a_line_bf16.py | 133 +++++++++ 4 files changed, 171 insertions(+), 309 deletions(-) create mode 100644 python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py diff --git 
a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index 16b117bf0f4e7..3bfc078971d7a 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -97,9 +97,6 @@ def update_loss_scaling(x, if e.dtype == core.VarDesc.VarType.FP16: assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." - elif e.dtype == core.VarDesc.VarType.BF16: - assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ - "The dtype of prev_loss_scaling should be float32 when the dtype of x is bfloat16." else: assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index 4ba7eea541a38..c64bd9703c8f7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -424,11 +424,9 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): if op.has_attr('dtype') and op.attr( 'dtype') == core.VarDesc.VarType.FP32: op._set_attr('dtype', core.VarDesc.VarType.BF16) - if op.has_attr('op_namescope') and \ - (_bf16_guard_pattern in op.attr("op_namescope")): - if op.has_attr('use_mkldnn'): - op._set_attr('use_mkldnn', True) - op._set_attr('mkldnn_data_type', 'bfloat16') + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + op._set_attr('mkldnn_data_type', 'bfloat16') # process ops in keep_fp32_ops op_var_rename_map = [ @@ -500,7 +498,7 @@ def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None): param_t.set(np.uint16(data), place) -def rewrite_program(main_prog, amp_lists): +def rewrite_program(main_prog, amp_lists, use_bf16_guard): """ Traverse all ops in current block and insert cast op according to which set current op belongs to. @@ -586,66 +584,22 @@ def rewrite_program(main_prog, amp_lists): core.VarDesc.VarType.BF16, core.VarDesc.VarType.FP32) elif op in white_op_set: - num_cast_ops = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16) + if use_bf16_guard: + if not (op.has_attr('op_namescope') and + (_bf16_guard_pattern in op.attr("op_namescope"))): + idx += 1 + continue + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + op._set_attr('mkldnn_data_type', 'bfloat16') + elif op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.BF16) + + num_cast_ops = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16) else: pass idx += num_cast_ops + 1 - - -def update_role_var_grad(main_prog, params_grads): - """ - Update op_role_var attr for some ops to make sure the gradients - transferred across GPUs is BF16. - 1. Check whether the op that outputs gradient is cast or not. - 2. If op is cast and gradient is FP32, remove the op_role_var - and find the prev op which outputs BF16 gradient - 3. Update the op_role_var of the prev op. - - Args: - main_prog (Program): The main program for training. - params_grads (list): A list of params and grads. 
- """ - block = main_prog.global_block() - BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward - OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize - for p, g in params_grads: - op = g.op - if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': - role = op.attr('op_role') - if role & int(BACKWARD) and op.has_attr('op_role_var'): - op.desc.remove_attr("op_role_var") - else: - raise ValueError("The cast op {0} must be in BACKWARD role " - "and have op_role_var attr.".format(op)) - - bf16_grad_name = op.input(op.input_names[0])[0] - op_for_bf16_grad = find_true_prev_op(block.ops, op, bf16_grad_name) - op_role_var_attr_name = \ - core.op_proto_and_checker_maker.kOpRoleVarAttrName() - attr_val = [p.name, bf16_grad_name] - if op_for_bf16_grad.has_attr(op_role_var_attr_name): - attr_val.extend(op_for_bf16_grad.attr(op_role_var_attr_name)) - op_for_bf16_grad._set_attr(op_role_var_attr_name, attr_val) - - # Maximize the all_reduce overlap, and perform the cast - # operation after gradients transfer. - op._set_attr('op_role', OPTIMIZE) - # optimize op should stay behind forward and backward ops - if op == block.ops[-1]: - continue - post_ops = find_true_post_op(block.ops, op, g.name) - if post_ops: - raise ValueError("The cast op {0}'s output should not be" - "used by a non-optimize op, however, it" - "is used by {1}".format(op, post_ops[0])) - new_op_desc = block.desc.append_op() - new_op_desc.copy_from(op.desc) - - op_idx = find_op_index(block.desc, op.desc) - if op_idx == -1: - raise ValueError("The op {0} is not in program".format(op)) - block.desc._remove_op(op_idx, op_idx + 1) - block._sync_with_cpp() diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py b/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py index 9946039ca47b4..b83b45cd5d46e 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py @@ -14,19 +14,11 @@ from ... import core from ... import default_main_program -from ... import default_startup_program -from ... import framework from ... import layers from ... import program_guard from ... import unique_name -from . import bf16_utils from .bf16_utils import rewrite_program -from .bf16_utils import cast_model_to_bf16 -from .bf16_utils import cast_parameters_to_bf16 -from .bf16_utils import update_role_var_grad from .bf16_lists import AutoMixedPrecisionListsBF16 -from .amp_nn import check_finite_and_unscale -from .amp_nn import update_loss_scaling import types import warnings @@ -45,89 +37,22 @@ class OptimizerWithMixedPrecision(object): Args: optimizer (Optimizer): A common Optimizer object. amp_lists (CustomOpLists): An CustomOpLists object. - init_loss_scaling (float): The initial loss scaling factor. - use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling. - incr_every_n_steps(int): Increases loss scaling every n consecutive - steps with finite gradients. - decr_every_n_nan_or_inf(int): Decreases loss scaling every n - accumulated steps with nan or - inf gradients. - incr_ratio(float): The multiplier to use when increasing the loss - scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing - the loss scaling. - use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. Default None, which means that its value is equal to `use_pure_bf16`. 
""" - def __init__(self, optimizer, amp_lists, init_loss_scaling, - use_dynamic_loss_scaling, incr_every_n_steps, - decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_bf16, - use_bf16_guard): + def __init__(self, optimizer, amp_lists, use_bf16_guard): self._optimizer = optimizer self._amp_lists = amp_lists - self._param_grads = None self._train_program = None - self._is_distributed = False - self._scaled_loss = None - self._loss_scaling = None - self._init_loss_scaling = init_loss_scaling - self._use_dynamic_loss_scaling = use_dynamic_loss_scaling self._learning_rate = optimizer._learning_rate self._learning_rate_map = optimizer._learning_rate_map - self._use_pure_bf16 = use_pure_bf16 self._use_bf16_guard = use_bf16_guard self._to_bf16_var_names = None - if self._use_dynamic_loss_scaling: - self._incr_every_n_steps = incr_every_n_steps - self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf - self._incr_ratio = incr_ratio - self._decr_ratio = decr_ratio - self._num_good_steps = None - self._num_bad_steps = None - - def _set_distributed(self, flag): - # if distributed, all cards will communication with each other, - # overlap communication and computation by split the - # check_finite_and_unscale op. - self._is_distributed = flag - - def get_loss_scaling(self): - """Return the real-time loss scaling factor. - """ - return self._loss_scaling - - def get_scaled_loss(self): - """Return the scaled loss. - It's useful when you feed customed loss into executor. - """ - return self._scaled_loss def _init_amp_var(self): - self._loss_scaling = layers.create_global_var( - name=unique_name.generate("loss_scaling"), - shape=[1], - value=self._init_loss_scaling, - dtype='float32', - persistable=True) - - if self._use_dynamic_loss_scaling: - self._num_good_steps = layers.create_global_var( - name=unique_name.generate("num_good_steps"), - shape=[1], - value=0, - dtype='int32', - persistable=True) - self._num_bad_steps = layers.create_global_var( - name=unique_name.generate("num_bad_steps"), - shape=[1], - value=0, - dtype='int32', - persistable=True) - # Ensure the data type of learning rate vars is float32 (same as the # master parameter dtype) if isinstance(self._optimizer._learning_rate, float): @@ -167,31 +92,21 @@ def backward(self, with program_guard(self._train_program, startup_program): self._init_amp_var() - if self._use_pure_bf16: - self._to_bf16_var_names = cast_model_to_bf16( - self._train_program, self._amp_lists, self._use_bf16_guard) - else: - rewrite_program(self._train_program, self._amp_lists) + rewrite_program(self._train_program, self._amp_lists, + self._use_bf16_guard) if loss.dtype != core.VarDesc.VarType.FP32: loss = loss.astype('float32') # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, # the model can be optimized. - if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0: - self._scaled_loss = loss * self._loss_scaling - else: - self._scaled_loss = loss + self._scaled_loss = loss params_grads = self._optimizer.backward( self._scaled_loss, startup_program, parameter_list, no_grad_set, callbacks) return params_grads - def amp_init(self, - place, - scope=None, - test_program=None, - use_bf16_test=False): + def amp_init(self, test_program=None, use_bf16_test=False): """ Init the amp training, such as cast fp32 parameters to bf16 type. 
@@ -211,7 +126,6 @@ def amp_init(self, paddle.enable_static() def run_example_code(): - place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) @@ -230,34 +144,26 @@ def run_example_code(): custom_black_list=['pool2d']) # 4) The entry of Paddle AMP. # Enable pure bf16 training by setting `use_pure_bf16` to True. - optimizer = paddle.static.amp.decorate( + optimizer = paddle.static.amp.decorate_bf16( optimizer, - amp_list, - init_loss_scaling=128.0, - use_dynamic_loss_scaling=True, - use_pure_bf16=True) + amp_list) # If you don't use the default_startup_program(), you sholud pass # your defined `startup_program` into `minimize`. optimizer.minimize(loss) exe.run(paddle.static.default_startup_program()) # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init(place, scope=paddle.static.global_scope()) + optimizer.amp_init() if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: run_example_code() """ assert self._train_program is not None, \ "Please call the minimize method first." - if self._use_pure_bf16: - cast_parameters_to_bf16(place, self._train_program, scope, - self._to_bf16_var_names) if test_program is not None: - if self._use_pure_bf16: - cast_model_to_bf16(test_program, self._amp_lists, - self._use_bf16_guard) - elif use_bf16_test: - rewrite_program(test_program, self._amp_lists) + if use_bf16_test: + rewrite_program(test_program, self._amp_lists, + self._use_bf16_guard) def apply_gradients(self, params_grads): """ @@ -271,102 +177,9 @@ def apply_gradients(self, params_grads): A list of optimize operators. """ - # Change the op_role_var attr for some ops, so that gradients - # transferred across GPUs can be BF16. - update_role_var_grad(self._train_program, params_grads) - # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, # the model can be optimized. - if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0: - return self._optimizer.apply_gradients(params_grads) - - grads = [g for _, g in params_grads] - fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] - bf16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.BF16] - assert len(fp32_grads) + len(bf16_grads) == len(grads), \ - "Data types of all grads must be either bf16 or fp32." 
- - found_infs = [] - if self._is_distributed: - # if distributed, split check_finite_and_unscale to overlap - # unscale with communication - for p, g in params_grads: - with self._train_program._optimized_guard([p, g]): - _, found_inf = check_finite_and_unscale( - [g, ], self._loss_scaling, name="find_infinite_scale") - found_infs.append(found_inf) - elif self._use_pure_bf16: - if fp32_grads: - with self._train_program._optimized_guard(fp32_grads): - _, fp32_found_inf = check_finite_and_unscale( - fp32_grads, - self._loss_scaling, - name="find_infinite_scale_fp32") - found_infs.append(fp32_found_inf) - if bf16_grads: - with self._train_program._optimized_guard(bf16_grads): - _, bf16_found_inf = check_finite_and_unscale( - bf16_grads, - self._loss_scaling, - name="find_infinite_scale_bf16") - found_infs.append(bf16_found_inf) - else: - with self._train_program._optimized_guard(grads): - _, found_inf = check_finite_and_unscale( - grads, self._loss_scaling, name="find_infinite_scale") - - if self._use_dynamic_loss_scaling: - if self._is_distributed or self._use_pure_bf16: - with self._train_program._optimized_guard([]): - all_infs = layers.concat(found_infs) - found_inf = layers.reduce_any(all_infs) - - if self._use_pure_bf16: - stop_update = False - with self._train_program._optimized_guard([]): - if fp32_grads: - update_loss_scaling( - fp32_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_fp32") - stop_update = True - if bf16_grads: - update_loss_scaling( - bf16_grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - stop_update=stop_update, - name="update_loss_scaling_bf16") - else: - with self._train_program._optimized_guard([]): - update_loss_scaling( - grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - name="update_loss_scaling") - - optimize_ops = self._optimizer.apply_gradients(params_grads) - return optimize_ops + return self._optimizer.apply_gradients(params_grads) def apply_optimize(self, loss, startup_program, params_grads): program = loss.block.program @@ -412,34 +225,13 @@ def minimize(self, return optimize_ops, scaled_params_grads -def decorate_bf16(optimizer, - amp_lists=None, - init_loss_scaling=2**15, - incr_every_n_steps=1000, - decr_every_n_nan_or_inf=2, - incr_ratio=2.0, - decr_ratio=0.8, - use_dynamic_loss_scaling=True, - use_pure_bf16=False, - use_bf16_guard=None): +def decorate_bf16(optimizer, amp_lists=None, use_bf16_guard=None): """ Decorate the given optimizer to adapt to the mixed-precision training. Args: optimizer(Optimizer): A common Optimizer. amp_lists (CustomOpLists): An CustomOpLists object. - init_loss_scaling(float): The initial loss scaling factor. - incr_every_n_steps(int): Increases loss scaling every n consecutive - steps with finite gradients. - decr_every_n_nan_or_inf(int): Decreases loss scaling every n - accumulated steps with nan or - inf gradients. - incr_ratio(float): The multiplier to use when increasing the loss - scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing - the loss scaling. 
- use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. - use_pure_bf16(bool): Whether to use the pure bf16 training. Default False. use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. Default None, which means that its value equals to `use_pure_bf16`. @@ -461,11 +253,9 @@ def decorate_bf16(optimizer, loss = paddle.mean(hidden) optimizer = paddle.optimizer.Adam(learning_rate=0.001) - mp_optimizer = static.amp.decorate( - optimizer=optimizer, init_loss_scaling=8.0) + mp_optimizer = static.amp.decorate(optimizer=optimizer) ops, param_grads = mp_optimizer.minimize(loss) - scaled_loss = mp_optimizer.get_scaled_loss() Examples 2: .. code-block:: python @@ -476,7 +266,6 @@ def decorate_bf16(optimizer, import paddle.nn.functional as F def run_example_code(): - place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) @@ -494,33 +283,22 @@ def run_example_code(): amp_list = paddle.static.amp.CustomOpLists( custom_black_list=['pool2d']) # 4) The entry of Paddle AMP. - # Enable pure bf16 training by setting `use_pure_bf16` to True. optimizer = paddle.static.amp.decorate( optimizer, - amp_list, - init_loss_scaling=128.0, - use_dynamic_loss_scaling=True, - use_pure_bf16=True) - # If you don't use the default_startup_program(), you sholud pass + amp_list) + # If you don't use the default_startup_program(), you should pass # your defined `startup_program` into `minimize`. optimizer.minimize(loss) exe.run(paddle.static.default_startup_program()) # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init(place, scope=paddle.static.global_scope()) + optimizer.amp_init() - if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: - run_example_code() """ if amp_lists is None: amp_lists = AutoMixedPrecisionListsBF16() - if use_bf16_guard is None: - use_bf16_guard = use_pure_bf16 - - mp_optimizer = OptimizerWithMixedPrecision( - optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling, - incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, - use_pure_bf16, use_bf16_guard) + mp_optimizer = OptimizerWithMixedPrecision(optimizer, amp_lists, + use_bf16_guard) return mp_optimizer diff --git a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py new file mode 100644 index 0000000000000..36f25758ae4c8 --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import contextlib +import numpy +import unittest + +paddle.enable_static() + + +def train(is_local): + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + + with paddle.static.amp.bf16_guard(): + y_predict = fluid.layers.fc(input=x, size=1, act=None) + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer = paddle.static.amp.decorate_bf16( + sgd_optimizer, use_bf16_guard=True) + sgd_optimizer.minimize(avg_cost) + + BATCH_SIZE = 20 + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + def train_loop(main_program): + feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) + exe.run(fluid.default_startup_program()) + sgd_optimizer.amp_init(exe.place) + + PASS_NUM = 1 + for pass_id in range(PASS_NUM): + for data in train_reader(): + avg_loss_value, = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[avg_cost]) + print(avg_loss_value) + + if is_local: + train_loop(fluid.default_main_program()) + + +def infer(save_dirname=None): + if save_dirname is None: + return + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + inference_scope = fluid.core.Scope() + with fluid.scope_guard(inference_scope): + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be fed + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). 
+ [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + # The input's dimension should be 2-D and the second dim is 13 + # The input data should be >= 0 + batch_size = 10 + + test_reader = paddle.batch( + paddle.dataset.uci_housing.test(), batch_size=batch_size) + + test_data = next(test_reader()) + test_feat = numpy.array( + [data[0] for data in test_data]).astype("float32") + test_label = numpy.array( + [data[1] for data in test_data]).astype("float32") + + assert feed_target_names[0] == 'x' + results = exe.run(inference_program, + feed={feed_target_names[0]: numpy.array(test_feat)}, + fetch_list=fetch_targets) + print("infer shape: ", results[0].shape) + print("infer results: ", results[0]) + print("ground truth: ", test_label) + + +def main(): + if not fluid.core.is_compiled_with_mkldnn(): + return + + # Directory for saving the trained model + save_dirname = "fit_a_line.inference.model" + + train(save_dirname) + infer(save_dirname) + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + main() + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() From 4e5555f682a5b2a8ea2a928e572b04665b25950c Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 1 Mar 2021 08:01:30 +0100 Subject: [PATCH 09/33] Changes for CI --- .../fluid/contrib/mixed_precision/__init__.py | 3 - .../contrib/mixed_precision/bf16_utils.py | 180 +---------- .../contrib/mixed_precision/decorator_bf16.py | 304 ------------------ .../contrib/tests/test_fit_a_line_bf16.py | 133 -------- 4 files changed, 1 insertion(+), 619 deletions(-) delete mode 100644 python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py delete mode 100644 python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 5b7efda36a74c..d246e04949ece 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -16,8 +16,6 @@ from . import decorator from .decorator import * -from . import decorator_bf16 -from .decorator_bf16 import * from . import fp16_lists from .fp16_lists import * from . 
import bf16_lists @@ -28,7 +26,6 @@ from .bf16_utils import * __all__ = decorator.__all__ -__all__ += decorator_bf16.__all__ __all__ += fp16_lists.__all__ __all__ += bf16_lists.__all__ __all__ += fp16_utils.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index c64bd9703c8f7..88d1cca5ac699 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -26,10 +26,7 @@ import logging import numpy as np -__all__ = [ - "bf16_guard", "cast_model_to_bf16", "cast_parameters_to_bf16", - "convert_float_to_uint16", "convert_uint16_to_float" -] +__all__ = ["bf16_guard", "convert_float_to_uint16", "convert_uint16_to_float"] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') @@ -323,181 +320,6 @@ def bf16_guard(): yield -def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): - """ - Traverse all ops in the whole model and set their inputs and outputs - to the bf16 data type. This function will do some special process for - the batch normalization, which keeps the computational process of - batchnorms in FP32. - Args: - program (Program): The used program. - amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object. - use_bf16_guard(bool): Determine whether to use `bf16_guard` when - constructing the program. Default True. - """ - - if amp_lists is None: - amp_lists = AutoMixedPrecisionListsBF16() - global_block = program.global_block() - keep_fp32_ops = set() - to_bf16_var_names = set() - to_bf16_pre_cast_ops = set() - origin_ops = [] - for block in program.blocks: - origin_ops.extend(block.ops) - - for block in program.blocks: - ops = block.ops - for op in ops: - if op.type == 'create_py_reader' or op.type == 'read': - continue - if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard): - keep_fp32_ops.add(op) - continue # processed below - for in_name in op.input_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and in_name not in {'X', 'Z'}: - continue - for in_var_name in op.input(in_name): - in_var = None - try: - in_var = block.var(in_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) - in_var = global_block.var(in_var_name) - if in_var is not None: - _logger.debug( - "-- var {} is got in the global block --". - format(in_var_name)) - - if in_var is None or in_var.type not in _valid_types: - continue - - if in_var.dtype == core.VarDesc.VarType.FP32: - if in_var.is_data: - to_bf16_pre_cast_ops.add(op) - else: - in_var.desc.set_dtype(core.VarDesc.VarType.BF16) - to_bf16_var_names.add(in_var_name) - - _logger.debug( - "-- op type: {}, in var name: {}, in var dtype: {} --". - format(op.type, in_var_name, in_var.dtype)) - - for out_name in op.output_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and out_name != 'Y': - continue - for out_var_name in op.output(out_name): - out_var = None - try: - out_var = block.var(out_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) - out_var = global_block.var(out_var_name) - if out_var is not None: - _logger.debug( - "-- var {} is got in the global block --". 
- format(out_var_name)) - - if out_var is None or out_var.type not in _valid_types: - continue - - if out_var.dtype == core.VarDesc.VarType.FP32: - out_var.desc.set_dtype(core.VarDesc.VarType.BF16) - - _logger.debug( - "-- op type: {}, out var name: {}, out var dtype: {} --". - format(op.type, out_var_name, out_var.dtype)) - if op.has_attr('in_dtype') and op.attr( - 'in_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('in_dtype', core.VarDesc.VarType.BF16) - if op.has_attr('out_dtype') and op.attr( - 'out_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('out_dtype', core.VarDesc.VarType.BF16) - if op.has_attr('dtype') and op.attr( - 'dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.BF16) - if op.has_attr('use_mkldnn'): - op._set_attr('use_mkldnn', True) - op._set_attr('mkldnn_data_type', 'bfloat16') - - # process ops in keep_fp32_ops - op_var_rename_map = [ - collections.OrderedDict() for _ in range(len(program.blocks)) - ] - for block in program.blocks: - ops = block.ops - idx = 0 - while idx < len(ops): - op = ops[idx] - num_cast_ops = 0 - if op not in keep_fp32_ops: - if op in to_bf16_pre_cast_ops: - in_var_cast_num = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16) - num_cast_ops += in_var_cast_num - else: - pre_cast_num = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.BF16, - core.VarDesc.VarType.FP32) - num_cast_ops += pre_cast_num - for out_var_name in op.output_arg_names: - out_var = block.vars.get(out_var_name) - if out_var is None or out_var.type not in _valid_types: - continue - if out_var.dtype == core.VarDesc.VarType.BF16: - out_var.desc.set_dtype(core.VarDesc.VarType.FP32) - post_ops = find_true_post_op(ops, op, out_var_name) - for post_op in post_ops: - if post_op in keep_fp32_ops: - continue - post_cast_num = _insert_cast_post_op( - block, op, idx + pre_cast_num + 1, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16, out_var_name, - op_var_rename_map) - num_cast_ops += post_cast_num - idx += num_cast_ops + 1 - - _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) - return to_bf16_var_names - - -def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None): - """ - Traverse all parameters in the whole model and set them to the BF16 data type. - Whereas, this function will keep parameters of batchnorms in FP32. - Args: - place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors. - program (Program): The used program. - scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. - Default is None. - to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names` - will be set to BF16. Usually, it is the returned - value of `cast_model_to_bf16` API. 
- """ - all_parameters = [] - for block in program.blocks: - all_parameters.extend(block.all_parameters()) - - bf16_var_names = to_bf16_var_names if to_bf16_var_names else set() - var_scope = scope if scope else global_scope() - for param in all_parameters: - if param.name in bf16_var_names: - _logger.debug("---- cast {} to bf16 dtype ----".format(param.name)) - param_t = var_scope.find_var(param.name).get_tensor() - data = np.array(param_t) - param_t.set(np.uint16(data), place) - - def rewrite_program(main_prog, amp_lists, use_bf16_guard): """ Traverse all ops in current block and insert cast op according to diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py b/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py deleted file mode 100644 index b83b45cd5d46e..0000000000000 --- a/python/paddle/fluid/contrib/mixed_precision/decorator_bf16.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ... import core -from ... import default_main_program -from ... import layers -from ... import program_guard -from ... import unique_name -from .bf16_utils import rewrite_program -from .bf16_lists import AutoMixedPrecisionListsBF16 -import types -import warnings - -__all__ = ["decorate_bf16"] - - -class OptimizerWithMixedPrecision(object): - """ - Optimizer with mixed-precision (MP) training. This is a wrapper of a common - optimizer, plus the support of mixed-precision pre-training. The object - of this class almost has the same behavior as the common optimizer, with the - methods `minimize()`, `backward()`, `apply_gradients()` implemented. - Additionally, it enables the MP training automatically, i.e, the creation - and maintenance of master parameters, scaling of loss, etc. - - Args: - optimizer (Optimizer): A common Optimizer object. - amp_lists (CustomOpLists): An CustomOpLists object. - use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. - Default None, which means that its value is equal to `use_pure_bf16`. 
- - """ - - def __init__(self, optimizer, amp_lists, use_bf16_guard): - self._optimizer = optimizer - self._amp_lists = amp_lists - self._train_program = None - - self._learning_rate = optimizer._learning_rate - self._learning_rate_map = optimizer._learning_rate_map - self._use_bf16_guard = use_bf16_guard - self._to_bf16_var_names = None - - def _init_amp_var(self): - # Ensure the data type of learning rate vars is float32 (same as the - # master parameter dtype) - if isinstance(self._optimizer._learning_rate, float): - self._optimizer._learning_rate_map[default_main_program()] = \ - layers.create_global_var( - name=unique_name.generate("learning_rate"), - shape=[1], - value=float(self._optimizer._learning_rate), - dtype='float32', - persistable=True) - - def backward(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None, - callbacks=None): - """ - Backward propagation or auto differentiation for gradients' computation. - - Args: - loss (Variable): The loss Variable to minimize. - startup_program (Program|None): The startup Program for initializing - parameters in `parameter_list`. - parameter_list (list|None): A list of Variables to update. - no_grad_set (set|None): A set of Variables should be ignored. - callbacks (list|None): A list of callable objects to run when appending - backward operator for one parameter. - - Returns: - A list of (param, grad), which is a tuple of a parameter and its - gradient respectively, and the scaled loss. - """ - train_program = loss.block.program - self._train_program = train_program - - with program_guard(self._train_program, startup_program): - self._init_amp_var() - - rewrite_program(self._train_program, self._amp_lists, - self._use_bf16_guard) - - if loss.dtype != core.VarDesc.VarType.FP32: - loss = loss.astype('float32') - # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, - # the model can be optimized. - self._scaled_loss = loss - - params_grads = self._optimizer.backward( - self._scaled_loss, startup_program, parameter_list, no_grad_set, - callbacks) - return params_grads - - def amp_init(self, test_program=None, use_bf16_test=False): - """ - Init the amp training, such as cast fp32 parameters to bf16 type. - - Args: - place(CUDAPlace): place is used to initialize - bf16 parameters with fp32 values. - scope(Scope): The scope is used to find fp32 parameters. - test_program(Program): The program is used for testing. - use_bf16_test(bool): Whether to use bf16 testing. - - Examples: - .. code-block:: python - - import numpy as np - import paddle - import paddle.nn.functional as F - paddle.enable_static() - - def run_example_code(): - exe = paddle.static.Executor(place) - data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') - conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) - # 1) Use bf16_guard to control the range of bf16 kernels used. - with paddle.static.amp.bf16_guard(): - bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") - pool = F.max_pool2d(bn, kernel_size=2, stride=2) - hidden = paddle.static.nn.fc(pool, size=10) - loss = paddle.mean(hidden) - # 2) Create the optimizer and set `multi_precision` to True. - # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. - optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) - # 3) These ops in `custom_black_list` will keep in the float32 computation type. 
- amp_list = paddle.static.amp.CustomOpLists( - custom_black_list=['pool2d']) - # 4) The entry of Paddle AMP. - # Enable pure bf16 training by setting `use_pure_bf16` to True. - optimizer = paddle.static.amp.decorate_bf16( - optimizer, - amp_list) - # If you don't use the default_startup_program(), you sholud pass - # your defined `startup_program` into `minimize`. - optimizer.minimize(loss) - exe.run(paddle.static.default_startup_program()) - # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). - # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init() - - if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: - run_example_code() - """ - assert self._train_program is not None, \ - "Please call the minimize method first." - if test_program is not None: - if use_bf16_test: - rewrite_program(test_program, self._amp_lists, - self._use_bf16_guard) - - def apply_gradients(self, params_grads): - """ - Check scaled gradients to determine whether to update loss scaling and update - parameters by their scaled gradients. - - Args: - params_grads (list): A list of params and scaled grads. - - Returns: - A list of optimize operators. - """ - - # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, - # the model can be optimized. - return self._optimizer.apply_gradients(params_grads) - - def apply_optimize(self, loss, startup_program, params_grads): - program = loss.block.program - with program_guard(program, startup_program): - optimize_ops = self.apply_gradients(params_grads) - return optimize_ops - - def minimize(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): - """ - Perform optimization by minimizing the given loss. - - Args: - loss (Variable): The loss Variable. - startup_program (Program): startup_program for initializing parameters - in `parameter_list`. - parameter_list (list): list of Variables to update. - no_grad_set (set|None): set of Variables should be ignored. - - Returns: - The scaled loss by scaling factor, the list of optimize ops, and a - list of scaled parameters and gradients. - """ - opt_dict = self._optimizer.__class__.__dict__ - if 'minimize' in opt_dict and isinstance(opt_dict['minimize'], - types.FunctionType): - warnings.warn( - "The decorated optimizer has its own `minimize` method, but it will not be executed." - ) - - scaled_params_grads = self.backward( - loss, - startup_program=startup_program, - parameter_list=parameter_list, - no_grad_set=no_grad_set) - - optimize_ops = self.apply_optimize(loss, startup_program, - scaled_params_grads) - - return optimize_ops, scaled_params_grads - - -def decorate_bf16(optimizer, amp_lists=None, use_bf16_guard=None): - """ - Decorate the given optimizer to adapt to the mixed-precision training. - - Args: - optimizer(Optimizer): A common Optimizer. - amp_lists (CustomOpLists): An CustomOpLists object. - use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program. - Default None, which means that its value equals to `use_pure_bf16`. - - Returns: - An optimizer acting like a normal one but with mixed-precision training - enabled. - - Examples 1: - .. 
code-block:: python - - # black&white list based strategy example - import paddle - import paddle.static as static - - paddle.enable_static() - - data = static.data(name='X', shape=[None, 1], dtype='float32') - hidden = static.nn.fc(x=data, size=10) - loss = paddle.mean(hidden) - optimizer = paddle.optimizer.Adam(learning_rate=0.001) - - mp_optimizer = static.amp.decorate(optimizer=optimizer) - - ops, param_grads = mp_optimizer.minimize(loss) - - Examples 2: - .. code-block:: python - - # pure bf16 training example - import numpy as np - import paddle - import paddle.nn.functional as F - - def run_example_code(): - exe = paddle.static.Executor(place) - data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') - conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) - # 1) Use bf16_guard to control the range of bf16 kernels used. - with paddle.static.amp.bf16_guard(): - bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") - pool = F.max_pool2d(bn, kernel_size=2, stride=2) - hidden = paddle.static.nn.fc(pool, size=10) - loss = paddle.mean(hidden) - # 2) Create the optimizer and set `multi_precision` to True. - # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. - optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) - # 3) These ops in `custom_black_list` will keep in the float32 computation type. - amp_list = paddle.static.amp.CustomOpLists( - custom_black_list=['pool2d']) - # 4) The entry of Paddle AMP. - optimizer = paddle.static.amp.decorate( - optimizer, - amp_list) - # If you don't use the default_startup_program(), you should pass - # your defined `startup_program` into `minimize`. - optimizer.minimize(loss) - exe.run(paddle.static.default_startup_program()) - # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). - # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init() - - """ - if amp_lists is None: - amp_lists = AutoMixedPrecisionListsBF16() - - mp_optimizer = OptimizerWithMixedPrecision(optimizer, amp_lists, - use_bf16_guard) - - return mp_optimizer diff --git a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py deleted file mode 100644 index 36f25758ae4c8..0000000000000 --- a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import print_function - -import paddle -import paddle.fluid as fluid -import contextlib -import numpy -import unittest - -paddle.enable_static() - - -def train(is_local): - x = fluid.layers.data(name='x', shape=[13], dtype='float32') - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - - with paddle.static.amp.bf16_guard(): - y_predict = fluid.layers.fc(input=x, size=1, act=None) - - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer = paddle.static.amp.decorate_bf16( - sgd_optimizer, use_bf16_guard=True) - sgd_optimizer.minimize(avg_cost) - - BATCH_SIZE = 20 - - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=BATCH_SIZE) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - def train_loop(main_program): - feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) - exe.run(fluid.default_startup_program()) - sgd_optimizer.amp_init(exe.place) - - PASS_NUM = 1 - for pass_id in range(PASS_NUM): - for data in train_reader(): - avg_loss_value, = exe.run(main_program, - feed=feeder.feed(data), - fetch_list=[avg_cost]) - print(avg_loss_value) - - if is_local: - train_loop(fluid.default_main_program()) - - -def infer(save_dirname=None): - if save_dirname is None: - return - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - inference_scope = fluid.core.Scope() - with fluid.scope_guard(inference_scope): - # Use fluid.io.load_inference_model to obtain the inference program desc, - # the feed_target_names (the names of variables that will be fed - # data using feed operators), and the fetch_targets (variables that - # we want to obtain data from using fetch operators). 
- [inference_program, feed_target_names, - fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) - - # The input's dimension should be 2-D and the second dim is 13 - # The input data should be >= 0 - batch_size = 10 - - test_reader = paddle.batch( - paddle.dataset.uci_housing.test(), batch_size=batch_size) - - test_data = next(test_reader()) - test_feat = numpy.array( - [data[0] for data in test_data]).astype("float32") - test_label = numpy.array( - [data[1] for data in test_data]).astype("float32") - - assert feed_target_names[0] == 'x' - results = exe.run(inference_program, - feed={feed_target_names[0]: numpy.array(test_feat)}, - fetch_list=fetch_targets) - print("infer shape: ", results[0].shape) - print("infer results: ", results[0]) - print("ground truth: ", test_label) - - -def main(): - if not fluid.core.is_compiled_with_mkldnn(): - return - - # Directory for saving the trained model - save_dirname = "fit_a_line.inference.model" - - train(save_dirname) - infer(save_dirname) - - -class TestFitALine(unittest.TestCase): - def test_cpu(self): - with self.program_scope_guard(): - main() - - @contextlib.contextmanager - def program_scope_guard(self): - prog = fluid.Program() - startup_prog = fluid.Program() - scope = fluid.core.Scope() - with fluid.scope_guard(scope): - with fluid.program_guard(prog, startup_prog): - yield - - -if __name__ == '__main__': - unittest.main() From 3c9e45f567dd159b7821362a5a92a6720a4ff2f2 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 1 Mar 2021 16:49:11 +0100 Subject: [PATCH 10/33] Changes for CI --- .../contrib/mixed_precision/bf16_utils.py | 289 +++++++++--------- 1 file changed, 143 insertions(+), 146 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index 88d1cca5ac699..1ad21fb721d36 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -26,7 +26,10 @@ import logging import numpy as np -__all__ = ["bf16_guard", "convert_float_to_uint16", "convert_uint16_to_float"] +__all__ = [ + "bf16_guard", "cast_model_to_bf16", "convert_float_to_uint16", + "convert_uint16_to_float" +] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') @@ -196,32 +199,6 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, return num_cast_ops -def find_true_prev_op(ops, cur_op, var_name): - """ - Find the true prev op that outputs var_name variable. - - Args: - ops (list): A list of ops. - cur_op (Operator): Current operator which has var_name variable. - var_name (string): Variable name. 
- """ - prev_op = [] - for op in ops: - if op == cur_op: - break - for out_name in op.output_names: - for out_var_name in op.output(out_name): - if out_var_name == var_name: - prev_op.append(op) - if prev_op: - if not len(prev_op) == 1: - raise ValueError("There must be only one previous op " - "that outputs {0} variable".format(var_name)) - else: - return prev_op[0] - return None - - def find_true_post_op(ops, cur_op, var_name): """ if there are post ops, return them, if there is no post op, @@ -246,27 +223,6 @@ def find_true_post_op(ops, cur_op, var_name): return post_op -def find_op_index(block_desc, cur_op_desc): - """ - """ - for idx in range(block_desc.op_size()): - if cur_op_desc == block_desc.op(idx): - return idx - return -1 - - -def _is_in_black_varnames(op, amp_lists): - for in_name in op.input_arg_names: - if in_name in amp_lists.black_varnames: - return True - - for out_name in op.output_arg_names: - if out_name in amp_lists.black_varnames: - return True - - return False - - def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard): if op.type in unsupported_op_list: # the highest priority condition: If ops don't have bf16 computing kernels, @@ -320,108 +276,149 @@ def bf16_guard(): yield -def rewrite_program(main_prog, amp_lists, use_bf16_guard): +def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): """ - Traverse all ops in current block and insert cast op according to - which set current op belongs to. - - 1. When an op belongs to the black list, add it to black set - 2. When an op belongs to the white list, add it to white set - 3. When an op belongs to the gray list. If one - of its inputs is the output of black set op or black list op, - add it to black set. If all of its previous ops are not black - op and one of its inputs is the output of white set op or - white list op, add it to white set. - 4. When an op isn't in the lists, add it to black op set. - 5. Add necessary cast ops to make sure that black set op will be - computed in fp32 mode, while white set op will be computed in - bf16 mode. - + Traverse all ops in the whole model and set their inputs and outputs + to the bf16 data type. This function will do some special process for + the batch normalization, which keeps the computational process of + batchnorms in FP32. Args: - main_prog (Program): The main program for training. + program (Program): The used program. + amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object. + use_bf16_guard(bool): Determine whether to use `bf16_guard` when + constructing the program. Default True. """ - block = main_prog.global_block() - ops = block.ops - white_op_set = set() - black_op_set = set() - for op in ops: - - # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, - # we don't need to handle reader op and the input of 'create_py_reader' is not - # in block, which may result in errors. - # See GeneratorLoader._init_non_iterable() for details. 
- if op.type == 'create_py_reader' or op.type == 'read': - continue - - if amp_lists.black_varnames is not None and _is_in_black_varnames( - op, amp_lists): - black_op_set.add(op) - continue - - if op.type in amp_lists.black_list: - black_op_set.add(op) - elif op.type in amp_lists.white_list: - white_op_set.add(op) - elif op.type in amp_lists.gray_list: - is_black_op = False - is_white_op = False + + if amp_lists is None: + amp_lists = AutoMixedPrecisionListsBF16() + global_block = program.global_block() + keep_fp32_ops = set() + to_bf16_var_names = set() + to_bf16_pre_cast_ops = set() + origin_ops = [] + for block in program.blocks: + origin_ops.extend(block.ops) + + for block in program.blocks: + ops = block.ops + for op in ops: + if op.type == 'create_py_reader' or op.type == 'read': + continue + if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard): + keep_fp32_ops.add(op) + continue # processed below for in_name in op.input_names: - # if this op has inputs - if in_name: - for in_var_name in op.input(in_name): + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and in_name not in {'X', 'Z'}: + continue + for in_var_name in op.input(in_name): + in_var = None + try: in_var = block.var(in_var_name) - # this in_var isn't the output of other op - if in_var.op is None: - continue - elif in_var.op is op: - prev_op = find_true_prev_op(ops, op, in_var_name) - if prev_op is None: - continue + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block --". + format(e)) + in_var = global_block.var(in_var_name) + if in_var is not None: + _logger.debug( + "-- var {} is got in the global block --". + format(in_var_name)) + + if in_var is None or in_var.type not in _valid_types: + continue + + if in_var.dtype == core.VarDesc.VarType.FP32: + if in_var.is_data: + to_bf16_pre_cast_ops.add(op) else: - prev_op = in_var.op - # if it's one of inputs - if prev_op in black_op_set or \ - prev_op.type in amp_lists.black_list: - is_black_op = True - elif prev_op in white_op_set or \ - prev_op.type in amp_lists.white_list: - is_white_op = True - if is_black_op: - black_op_set.add(op) - elif is_white_op: - white_op_set.add(op) - else: - pass - else: - # For numerical safe, we apply fp32 computation on ops that - # are not determined which list they should stay. - black_op_set.add(op) - - idx = 0 - while idx < len(ops): - op = ops[idx] - num_cast_ops = 0 - if op in black_op_set: - num_cast_ops = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.BF16, - core.VarDesc.VarType.FP32) - elif op in white_op_set: - if use_bf16_guard: - if not (op.has_attr('op_namescope') and - (_bf16_guard_pattern in op.attr("op_namescope"))): - idx += 1 - continue - if op.has_attr('use_mkldnn'): - op._set_attr('use_mkldnn', True) - op._set_attr('mkldnn_data_type', 'bfloat16') - elif op.has_attr('dtype') and op.attr( - 'dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.BF16) - - num_cast_ops = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16) - else: - pass + in_var.desc.set_dtype(core.VarDesc.VarType.BF16) + to_bf16_var_names.add(in_var_name) + + _logger.debug( + "-- op type: {}, in var name: {}, in var dtype: {} --". 
+ format(op.type, in_var_name, in_var.dtype)) - idx += num_cast_ops + 1 + for out_name in op.output_names: + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and out_name != 'Y': + continue + for out_var_name in op.output(out_name): + out_var = None + try: + out_var = block.var(out_var_name) + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block --". + format(e)) + out_var = global_block.var(out_var_name) + if out_var is not None: + _logger.debug( + "-- var {} is got in the global block --". + format(out_var_name)) + + if out_var is None or out_var.type not in _valid_types: + continue + + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.BF16) + + _logger.debug( + "-- op type: {}, out var name: {}, out var dtype: {} --". + format(op.type, out_var_name, out_var.dtype)) + if op.has_attr('in_dtype') and op.attr( + 'in_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('in_dtype', core.VarDesc.VarType.BF16) + if op.has_attr('out_dtype') and op.attr( + 'out_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('out_dtype', core.VarDesc.VarType.BF16) + if op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.BF16) + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + op._set_attr('mkldnn_data_type', 'bfloat16') + + # process ops in keep_fp32_ops + op_var_rename_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + num_cast_ops = 0 + if op not in keep_fp32_ops: + if op in to_bf16_pre_cast_ops: + in_var_cast_num = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16) + num_cast_ops += in_var_cast_num + else: + pre_cast_num = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32) + num_cast_ops += pre_cast_num + for out_var_name in op.output_arg_names: + out_var = block.vars.get(out_var_name) + if out_var is None or out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.BF16: + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + post_ops = find_true_post_op(ops, op, out_var_name) + for post_op in post_ops: + if post_op in keep_fp32_ops: + continue + post_cast_num = _insert_cast_post_op( + block, op, idx + pre_cast_num + 1, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16, out_var_name, + op_var_rename_map) + num_cast_ops += post_cast_num + idx += num_cast_ops + 1 + + _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) + return to_bf16_var_names From a6a951817237b72aad9b021961e72c2643a66774 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 1 Mar 2021 19:35:46 +0100 Subject: [PATCH 11/33] Changes for CI --- .../paddle/fluid/contrib/mixed_precision/bf16_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index 1ad21fb721d36..609f6f2c6f7e6 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -223,6 +223,15 @@ def find_true_post_op(ops, cur_op, var_name): return post_op +def find_op_index(block_desc, cur_op_desc): + """ + """ + for idx in range(block_desc.op_size()): + if cur_op_desc == block_desc.op(idx): + return idx + return -1 + + def 
_need_keep_fp32(op, unsupported_op_list, use_bf16_guard): if op.type in unsupported_op_list: # the highest priority condition: If ops don't have bf16 computing kernels, From 8c1b7b72f09ee7e7c54907dbc7c719444c5c7a51 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 3 Mar 2021 16:36:39 +0100 Subject: [PATCH 12/33] Improvements --- .../contrib/mixed_precision/bf16_utils.py | 294 ++++++++---------- .../contrib/tests/test_model_cast_to_bf16.py | 7 +- 2 files changed, 134 insertions(+), 167 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py index 609f6f2c6f7e6..8fa64aece24f7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py @@ -27,7 +27,7 @@ import numpy as np __all__ = [ - "bf16_guard", "cast_model_to_bf16", "convert_float_to_uint16", + "bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16", "convert_uint16_to_float" ] @@ -167,36 +167,30 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): return num_cast_ops -def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, - op_var_rename_map): - num_cast_ops = 0 - - target_var = block.var(target_name) - if target_var.type not in _valid_types or target_var.dtype == dest_dtype: - return num_cast_ops - - assert target_var.dtype == src_dtype, \ - "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) - - cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) - cast_var = block.vars.get(cast_name) - if cast_var is None or cast_var.dtype != dest_dtype: - cast_var = block.create_var( - name=cast_name, - dtype=dest_dtype, - persistable=False, - stop_gradient=target_var.stop_gradient) - block._insert_op( - idx, - type="cast", - inputs={"X": target_var}, - outputs={"Out": cast_var}, - attrs={"in_dtype": target_var.dtype, - "out_dtype": cast_var.dtype}) - num_cast_ops += 1 - op_var_rename_map[block.idx][target_var.name] = cast_var.name +def find_true_prev_op(ops, cur_op, var_name): + """ + Find the true prev op that outputs var_name variable. - return num_cast_ops + Args: + ops (list): A list of ops. + cur_op (Operator): Current operator which has var_name variable. + var_name (string): Variable name. 
+ """ + prev_op = [] + for op in ops: + if op == cur_op: + break + for out_name in op.output_names: + for out_var_name in op.output(out_name): + if out_var_name == var_name: + prev_op.append(op) + if prev_op: + if not len(prev_op) == 1: + raise ValueError("There must be only one previous op " + "that outputs {0} variable".format(var_name)) + else: + return prev_op[0] + return None def find_true_post_op(ops, cur_op, var_name): @@ -232,6 +226,18 @@ def find_op_index(block_desc, cur_op_desc): return -1 +def _is_in_black_varnames(op, amp_lists): + for in_name in op.input_arg_names: + if in_name in amp_lists.black_varnames: + return True + + for out_name in op.output_arg_names: + if out_name in amp_lists.black_varnames: + return True + + return False + + def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard): if op.type in unsupported_op_list: # the highest priority condition: If ops don't have bf16 computing kernels, @@ -285,149 +291,111 @@ def bf16_guard(): yield -def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): +def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False): """ - Traverse all ops in the whole model and set their inputs and outputs - to the bf16 data type. This function will do some special process for - the batch normalization, which keeps the computational process of - batchnorms in FP32. + Traverse all ops in current block and insert cast op according to + which set current op belongs to. + + 1. When an op belongs to the black list, add it to black set + 2. When an op belongs to the white list, add it to white set + 3. When an op belongs to the gray list. If one + of its inputs is the output of black set op or black list op, + add it to black set. If all of its previous ops are not black + op and one of its inputs is the output of white set op or + white list op, add it to white set. + 4. When an op isn't in the lists, add it to black op set. + 5. Add necessary cast ops to make sure that black set op will be + computed in fp32 mode, while white set op will be computed in + bf16 mode. + Args: - program (Program): The used program. - amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object. - use_bf16_guard(bool): Determine whether to use `bf16_guard` when - constructing the program. Default True. + main_prog (Program): The main program for training. """ - if amp_lists is None: amp_lists = AutoMixedPrecisionListsBF16() - global_block = program.global_block() - keep_fp32_ops = set() - to_bf16_var_names = set() - to_bf16_pre_cast_ops = set() - origin_ops = [] - for block in program.blocks: - origin_ops.extend(block.ops) - - for block in program.blocks: - ops = block.ops - for op in ops: - if op.type == 'create_py_reader' or op.type == 'read': - continue - if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard): - keep_fp32_ops.add(op) - continue # processed below + block = main_prog.global_block() + ops = block.ops + white_op_set = set() + black_op_set = set() + for op in ops: + + # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, + # we don't need to handle reader op and the input of 'create_py_reader' is not + # in block, which may result in errors. + # See GeneratorLoader._init_non_iterable() for details. 
+ if op.type == 'create_py_reader' or op.type == 'read': + continue + + if amp_lists.black_varnames is not None and _is_in_black_varnames( + op, amp_lists): + black_op_set.add(op) + continue + + if op.type in amp_lists.black_list or _need_keep_fp32( + op, amp_lists.unsupported_list, use_bf16_guard): + black_op_set.add(op) + elif op.type in amp_lists.white_list: + white_op_set.add(op) + elif op.type in amp_lists.gray_list: + is_black_op = False + is_white_op = False for in_name in op.input_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and in_name not in {'X', 'Z'}: - continue - for in_var_name in op.input(in_name): - in_var = None - try: + # if this op has inputs + if in_name: + for in_var_name in op.input(in_name): in_var = block.var(in_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) - in_var = global_block.var(in_var_name) - if in_var is not None: - _logger.debug( - "-- var {} is got in the global block --". - format(in_var_name)) - - if in_var is None or in_var.type not in _valid_types: - continue - - if in_var.dtype == core.VarDesc.VarType.FP32: - if in_var.is_data: - to_bf16_pre_cast_ops.add(op) + # this in_var isn't the output of other op + if in_var.op is None: + continue + elif in_var.op is op: + prev_op = find_true_prev_op(ops, op, in_var_name) + if prev_op is None: + continue else: - in_var.desc.set_dtype(core.VarDesc.VarType.BF16) - to_bf16_var_names.add(in_var_name) - - _logger.debug( - "-- op type: {}, in var name: {}, in var dtype: {} --". - format(op.type, in_var_name, in_var.dtype)) - - for out_name in op.output_names: - if op.type in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - } and out_name != 'Y': + prev_op = in_var.op + # if it's one of inputs + if prev_op in black_op_set or \ + prev_op.type in amp_lists.black_list: + is_black_op = True + elif prev_op in white_op_set or \ + prev_op.type in amp_lists.white_list: + is_white_op = True + if is_black_op: + black_op_set.add(op) + elif is_white_op: + white_op_set.add(op) + else: + pass + else: + # For numerical safe, we apply fp32 computation on ops that + # are not determined which list they should stay. + black_op_set.add(op) + + idx = 0 + while idx < len(ops): + op = ops[idx] + num_cast_ops = 0 + if op in black_op_set: + num_cast_ops = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32) + elif op in white_op_set: + if use_bf16_guard: + if not (op.has_attr('op_namescope') and + (_bf16_guard_pattern in op.attr("op_namescope"))): + idx += 1 continue - for out_var_name in op.output(out_name): - out_var = None - try: - out_var = block.var(out_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --". - format(e)) - out_var = global_block.var(out_var_name) - if out_var is not None: - _logger.debug( - "-- var {} is got in the global block --". - format(out_var_name)) - - if out_var is None or out_var.type not in _valid_types: - continue - - if out_var.dtype == core.VarDesc.VarType.FP32: - out_var.desc.set_dtype(core.VarDesc.VarType.BF16) - - _logger.debug( - "-- op type: {}, out var name: {}, out var dtype: {} --". 
- format(op.type, out_var_name, out_var.dtype)) - if op.has_attr('in_dtype') and op.attr( - 'in_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('in_dtype', core.VarDesc.VarType.BF16) - if op.has_attr('out_dtype') and op.attr( - 'out_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('out_dtype', core.VarDesc.VarType.BF16) - if op.has_attr('dtype') and op.attr( - 'dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.BF16) if op.has_attr('use_mkldnn'): op._set_attr('use_mkldnn', True) op._set_attr('mkldnn_data_type', 'bfloat16') + elif op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.BF16) - # process ops in keep_fp32_ops - op_var_rename_map = [ - collections.OrderedDict() for _ in range(len(program.blocks)) - ] - for block in program.blocks: - ops = block.ops - idx = 0 - while idx < len(ops): - op = ops[idx] - num_cast_ops = 0 - if op not in keep_fp32_ops: - if op in to_bf16_pre_cast_ops: - in_var_cast_num = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16) - num_cast_ops += in_var_cast_num - else: - pre_cast_num = _insert_cast_op(block, op, idx, - core.VarDesc.VarType.BF16, - core.VarDesc.VarType.FP32) - num_cast_ops += pre_cast_num - for out_var_name in op.output_arg_names: - out_var = block.vars.get(out_var_name) - if out_var is None or out_var.type not in _valid_types: - continue - if out_var.dtype == core.VarDesc.VarType.BF16: - out_var.desc.set_dtype(core.VarDesc.VarType.FP32) - post_ops = find_true_post_op(ops, op, out_var_name) - for post_op in post_ops: - if post_op in keep_fp32_ops: - continue - post_cast_num = _insert_cast_post_op( - block, op, idx + pre_cast_num + 1, - core.VarDesc.VarType.FP32, - core.VarDesc.VarType.BF16, out_var_name, - op_var_rename_map) - num_cast_ops += post_cast_num - idx += num_cast_ops + 1 - - _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) - return to_bf16_var_names + num_cast_ops = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16) + else: + pass + + idx += num_cast_ops + 1 diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 7c29e2fd732c4..3246c454548a6 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -21,7 +21,7 @@ import numpy as np import paddle.fluid.layers as layers from paddle.fluid import core -from paddle.fluid.contrib.mixed_precision.bf16_utils import cast_model_to_bf16,\ +from paddle.fluid.contrib.mixed_precision.bf16_utils import rewrite_program_bf16,\ convert_float_to_uint16, convert_uint16_to_float paddle.enable_static() @@ -58,7 +58,7 @@ def get_static_graph_result(self, feed, fetch_list, with_lod=False): exe = fluid.Executor(core.CPUPlace()) exe.run(fluid.default_startup_program()) prog = fluid.default_main_program() - cast_model_to_bf16(prog, use_bf16_guard=True) + rewrite_program_bf16(prog, use_bf16_guard=True) return exe.run(prog, feed=feed, fetch_list=fetch_list, @@ -95,8 +95,7 @@ def test_elementwise_math(self): }, fetch_list=[ret_bf16, ret]) - stt = convert_uint16_to_float(static_ret_bf16) - self.assertTrue(np.allclose(stt, static_ret, 1e-2)) + self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2)) if __name__ == '__main__': From 654904b1fc884d7b3d94a34945a4775c98874b08 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 10 Mar 2021 
10:21:32 +0100 Subject: [PATCH 13/33] Refactor --- .../fluid/contrib/mixed_precision/__init__.py | 9 +- .../contrib/mixed_precision/bf16/__init__.py | 24 ++ .../contrib/mixed_precision/bf16/amp_lists.py | 279 ++++++++++++++++++ .../{bf16_utils.py => bf16/amp_utils.py} | 111 +++---- .../contrib/mixed_precision/bf16_lists.py | 46 --- .../contrib/mixed_precision/fp16_lists.py | 33 +-- .../fluid/contrib/tests/test_bf16_utils.py | 163 +++++----- .../contrib/tests/test_model_cast_to_bf16.py | 18 +- python/paddle/fluid/layers/nn.py | 6 +- 9 files changed, 439 insertions(+), 250 deletions(-) create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py create mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py rename python/paddle/fluid/contrib/mixed_precision/{bf16_utils.py => bf16/amp_utils.py} (80%) delete mode 100644 python/paddle/fluid/contrib/mixed_precision/bf16_lists.py diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index d246e04949ece..5bc6333531a51 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -18,15 +18,12 @@ from .decorator import * from . import fp16_lists from .fp16_lists import * -from . import bf16_lists -from .bf16_lists import * from . import fp16_utils from .fp16_utils import * -from . import bf16_utils -from .bf16_utils import * +from . import bf16 +from .bf16 import * __all__ = decorator.__all__ __all__ += fp16_lists.__all__ -__all__ += bf16_lists.__all__ __all__ += fp16_utils.__all__ -__all__ += bf16_utils.__all__ +__all__ += ['bf16'] diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py b/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py new file mode 100644 index 0000000000000..8c05bc4899cf7 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import amp_lists +from .amp_lists import * +from . import amp_utils +from .amp_utils import * + +__all__ = [] +__all__ += amp_lists.__all__ +__all__ += amp_utils.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py new file mode 100644 index 0000000000000..94a1561863875 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -0,0 +1,279 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +__all__ = ["AutoMixedPrecisionLists"] + + +class AutoMixedPrecisionLists(object): + """ + AutoMixedPrecisionLists is a class for fp32/bf16 list. It can update + pre-defined fp32 list and bf16 list according to users' custom fp32 + bf16 lists. The lists are used for an algorithm which determines op's + execution mode (fp32 or bf16). + + Args: + custom_bf16_list (set): Users' custom bf16 list. + custom_fp32_list (set): Users' custom fp32 list. + custom_fp32_varnames (set): Users' custom fp32 variables' names. + """ + + def __init__(self, + custom_bf16_list=None, + custom_fp32_list=None, + custom_fp32_varnames=None): + self._custom_bf16_list = custom_bf16_list + self._custom_fp32_list = custom_fp32_list + self.bf16_list = copy.copy(bf16_list) + self.fp32_list = copy.copy(fp32_list) + self.gray_list = copy.copy(gray_list) + self.unsupported_list = copy.copy(unsupported_list) + self.fp32_varnames = copy.copy(custom_fp32_varnames) + self._update_list() + + def _update_list(self): + """ + Update fp32 and bf16 list according to users' custom list. + """ + if self._custom_bf16_list and self._custom_fp32_list: + for op_name in self._custom_bf16_list: + if op_name in self._custom_fp32_list: + raise ValueError("Custom bf16 list overlap " + "custom fp32 list") + if self._custom_bf16_list: + for op_name in self._custom_bf16_list: + if op_name in self.fp32_list: + self.fp32_list.remove(op_name) + elif op_name in self.gray_list: + self.gray_list.remove(op_name) + self.bf16_list.add(op_name) + if self._custom_fp32_list: + for op_name in self._custom_fp32_list: + if op_name in self.bf16_list: + self.bf16_list.remove(op_name) + elif op_name in self.gray_list: + self.gray_list.remove(op_name) + self.fp32_list.add(op_name) + self.unsupported_list.add(op_name) + + +# always bf16 +bf16_list = {'elementwise_add', } + +# depends on the prev_op type +gray_list = {'reshape2', } + +# always fp32 +fp32_list = { + 'conv2d', + 'matmul', + 'matmul_v2', + 'mul', + 'exp', + 'square', + 'log', + 'mean', + 'sum', + 'cos_sim', + 'softmax', + 'softmax_with_cross_entropy', + 'sigmoid_cross_entropy_with_logits', + 'cross_entropy', + 'cross_entropy2', + 'lookup_table', + 'lookup_table_v2', + # 'elementwise_add', + 'elementwise_sub', + 'elementwise_mul', + 'elementwise_div', + 'elementwise_max', + 'elementwise_min', + 'elementwise_pow', + 'elementwise_mod', + 'elementwise_floordiv', + 'batch_norm', + 'layer_norm', + 'tanh', + 'sigmoid', + 'top_k', + 'pool2d', + 'pool3d', + 'dropout', + 'relu', + 'relu6', + 'leaky_relu', + 'soft_relu', + 'flatten2', + 'stack', + 'unstack', + 'uniform_random', + 'uniform_random_batch_size_like', + 'gaussian_random', + 'gaussian_random_batch_size_like', + 'slice', + 'rank', + 'scale', + 'transpose2', + # 'reshape2', + 'gather', + 'fill_constant', + 'get_tensor_from_selected_rows', + 'sign', + 'cast', + 'fused_bn_add_activation', +} + +# The set of ops that don't support bf16 calculation +unsupported_list = { + # from python/paddle/fluid/layers/io.py + 'send', + 'send_barrier', + 'recv', + 'fetch_barrier', + 'create_py_reader', + 'create_double_buffer_reader', + 'read', + 
'load', + + # from python/paddle/fluid/control_flow.py + 'increment', + 'less_than', + 'less_equal', + 'greater_than', + 'greater_equal', + 'equal', + 'not_equal', + 'read_from_array', + 'shrink_rnn_memory', + 'lod_array_length', + 'logical_and', + 'logical_or', + 'logical_xor', + 'logical_not', + 'print', + 'conditional_block', + 'while', + 'ifelse', + 'is_empty', + 'lstm', + 'cudnn_lstm', + 'lstmp', + 'gru', + 'gru_unit', + 'linear_chain_crf', + 'crf_decoding', + 'bpr_loss', + 'chunk_eval', + 'sequence_conv', + 'sequence_softmax', + # Depthwise conv2d isn't fast and safe currently. + # ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79 + 'depthwise_conv2d', + # Tensor Core kernels are not available for 3D convolutions currently. + 'conv3d', + 'sequence_pool', + 'sequence_concat', + 'sequence_slice', + 'data_norm', + 'group_norm', + 'spectral_norm', + 'depthwise_conv2d_transpose', + 'sequence_expand', + 'conv_transposed2d', + 'conv_transposed3d', + 'sequence_expand_as', + 'sequence_pad', + 'sequence_unpad', + 'sequence_erase', + 'beam_search', + 'beam_search_decode', + 'lstm_unit', + 'reduce_sum', + 'reduce_mean', + 'reduce_max', + 'reduce_min', + 'reduce_prod', + 'reduce_all', + 'reduce_any', + 'split', + 'edit_distance', + 'ctc_align', + 'warpctc', + 'sequence_reshape', + 'nce', + 'hierarchical_sigmoid', + 'im2sequence', + 'row_conv', + 'multiplex', + 'sample_logits', + 'one_hot', + 'smooth_l1_loss', + 'squeeze2', + 'unsqueeze2', + 'lod_reset', + 'lrn', + 'pad', + 'pad_constant_like', + 'label_smooth', + 'scatter', + 'sequence_scatter', + 'random_crop', + 'mean_iou', + 'selu', + 'crop', + 'affine_grid', + 'rank_loss', + 'margin_rank_loss', + 'pad2d', + 'elu', + 'pow', + 'stanh', + 'hard_sigmoid', + 'swish', + 'prelu', + 'brelu', + 'sequence_enumerate', + 'sequence_mask', + 'expand', + 'sampling_id', + 'maxout', + 'space_to_depth', + 'sequence_reverse', + 'similarity_focus', + 'hash', + 'grid_sampler', + 'log_loss', + 'teacher_student_sigmoid_loss', + 'add_position_encoding', + 'bilinear_tensor_product', + 'shuffle_channel', + 'temporal_shift', + 'psroi_pool', + 'huber_loss', + 'kldiv_loss', + 'tree_conv', + 'pixel_shuffle', + 'fsp', + 'cvm', + 'affine_channel', + 'roi_pool', + 'roi_align', + 'anchor_generator', + 'generate_proposals', + 'generate_proposal_labels', + 'generate_mask_labels', + 'lookup_table', + 'lookup_table_v2', +} diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py similarity index 80% rename from python/paddle/fluid/contrib/mixed_precision/bf16_utils.py rename to python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py index 8fa64aece24f7..1fd70a85c92f7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -3,6 +3,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at + # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -15,21 +16,15 @@ from __future__ import print_function import struct -from ... import core -from ... import framework -from ... import layers -from ... import global_scope -from ...log_helper import get_logger -from ...wrapped_decorator import signature_safe_contextmanager -from .bf16_lists import AutoMixedPrecisionListsBF16 -import collections +from .... import core +from .... 
import framework +from ....log_helper import get_logger +from ....wrapped_decorator import signature_safe_contextmanager +from .amp_lists import AutoMixedPrecisionLists import logging import numpy as np -__all__ = [ - "bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16", - "convert_uint16_to_float" -] +__all__ = ["bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16"] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') @@ -50,14 +45,6 @@ def convert_float_to_uint16(in_list): return np.reshape(out, in_list.shape) -def convert_uint16_to_float(in_list): - in_list = np.asarray(in_list) - out = np.vectorize( - lambda x: struct.unpack(' Date: Wed, 10 Mar 2021 11:27:51 +0100 Subject: [PATCH 14/33] Refactor --- python/paddle/fluid/contrib/mixed_precision/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 5bc6333531a51..60c26542be93d 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -20,8 +20,6 @@ from .fp16_lists import * from . import fp16_utils from .fp16_utils import * -from . import bf16 -from .bf16 import * __all__ = decorator.__all__ __all__ += fp16_lists.__all__ From be8e75993191e4da7cdaf1c20063cf18cd5a8997 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 10 Mar 2021 12:46:44 +0100 Subject: [PATCH 15/33] Changes for CI --- .../fluid/contrib/mixed_precision/__init__.py | 2 ++ .../contrib/mixed_precision/bf16/amp_lists.py | 6 +++--- .../contrib/mixed_precision/bf16/amp_utils.py | 4 ++-- .../fluid/contrib/tests/test_bf16_utils.py | 18 +++++++++--------- python/paddle/fluid/dygraph/amp/auto_cast.py | 4 ++-- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 60c26542be93d..5bc6333531a51 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -20,6 +20,8 @@ from .fp16_lists import * from . import fp16_utils from .fp16_utils import * +from . import bf16 +from .bf16 import * __all__ = decorator.__all__ __all__ += fp16_lists.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 94a1561863875..6a4840f272ce5 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -14,12 +14,12 @@ import copy -__all__ = ["AutoMixedPrecisionLists"] +__all__ = ["AutoMixedPrecisionListsBF16"] -class AutoMixedPrecisionLists(object): +class AutoMixedPrecisionListsBF16(object): """ - AutoMixedPrecisionLists is a class for fp32/bf16 list. It can update + AutoMixedPrecisionListsBF16 is a class for fp32/bf16 list. It can update pre-defined fp32 list and bf16 list according to users' custom fp32 bf16 lists. The lists are used for an algorithm which determines op's execution mode (fp32 or bf16). diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py index 1fd70a85c92f7..f64312240e9e0 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -20,7 +20,7 @@ from .... 
import framework from ....log_helper import get_logger from ....wrapped_decorator import signature_safe_contextmanager -from .amp_lists import AutoMixedPrecisionLists +from .amp_lists import AutoMixedPrecisionListsBF16 import logging import numpy as np @@ -287,7 +287,7 @@ def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False): main_prog (Program): The main program for training. """ if amp_lists is None: - amp_lists = AutoMixedPrecisionLists() + amp_lists = AutoMixedPrecisionListsBF16() block = main_prog.global_block() ops = block.ops bf16_op_set = set() diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index e2ef23c99f8d0..8cfe7f9accd20 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -17,7 +17,7 @@ from paddle.fluid import core from paddle.fluid.contrib.mixed_precision.bf16 import amp_utils from paddle.fluid.contrib.mixed_precision.bf16 import amp_lists -from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionLists +from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionListsBF16 import paddle paddle.enable_static() @@ -29,7 +29,7 @@ def test_amp_lists(self): fp32_list = copy.copy(amp_lists.fp32_list) gray_list = copy.copy(amp_lists.gray_list) - amp_lists_ = AutoMixedPrecisionLists() + amp_lists_ = AutoMixedPrecisionListsBF16() self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -43,7 +43,7 @@ def test_amp_lists_1(self): bf16_list.add('exp') fp32_list.remove('exp') - amp_lists_ = AutoMixedPrecisionLists({'exp'}) + amp_lists_ = AutoMixedPrecisionListsBF16({'exp'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -57,7 +57,7 @@ def test_amp_lists_2(self): fp32_list.remove('tanh') bf16_list.add('tanh') - amp_lists_ = AutoMixedPrecisionLists({'tanh'}) + amp_lists_ = AutoMixedPrecisionListsBF16({'tanh'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -70,7 +70,7 @@ def test_amp_lists_3(self): # 3. w={'lstm'}, b=None bf16_list.add('lstm') - amp_lists_ = AutoMixedPrecisionLists({'lstm'}) + amp_lists_ = AutoMixedPrecisionListsBF16({'lstm'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -84,7 +84,7 @@ def test_amp_lists_4(self): bf16_list.remove('elementwise_add') fp32_list.add('elementwise_add') - amp_lists_ = AutoMixedPrecisionLists( + amp_lists_ = AutoMixedPrecisionListsBF16( custom_fp32_list={'elementwise_add'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) @@ -99,7 +99,7 @@ def test_amp_lists_5(self): fp32_list.add('elementwise_add') bf16_list.remove('elementwise_add') - amp_lists_ = AutoMixedPrecisionLists( + amp_lists_ = AutoMixedPrecisionListsBF16( custom_fp32_list={'elementwise_add'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) @@ -113,7 +113,7 @@ def test_amp_lists_6(self): # 6. 
w=None, b={'lstm'} fp32_list.add('lstm') - amp_lists_ = AutoMixedPrecisionLists(custom_fp32_list={'lstm'}) + amp_lists_ = AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -121,7 +121,7 @@ def test_amp_lists_6(self): def test_amp_lists_7(self): # 7. w={'lstm'} b={'lstm'} # raise ValueError - self.assertRaises(ValueError, AutoMixedPrecisionLists, {'lstm'}, + self.assertRaises(ValueError, AutoMixedPrecisionListsBF16, {'lstm'}, {'lstm'}) def test_find_op_index(self): diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 4ff08337875c0..2207bdfc83505 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -60,8 +60,8 @@ } -#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list -# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. +#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionListsBF16._update_list +# The reason why not use AutoMixedPrecisionListsBF16 is that custom_black_varnames is not suitable for imperative mode. def _update_list(custom_white_list, custom_black_list): """ Update black and white list according to users' custom list. From fc665a529a5c40c4e8c4030ffa24258b528c85b0 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 10 Mar 2021 15:23:15 +0100 Subject: [PATCH 16/33] Changes for CI --- .../fluid/contrib/mixed_precision/__init__.py | 4 +- .../fluid/contrib/tests/test_bf16_utils.py | 69 +++++++++---------- .../contrib/tests/test_model_cast_to_bf16.py | 8 +-- python/paddle/static/amp/__init__.py | 2 + python/setup.py.in | 1 + 5 files changed, 43 insertions(+), 41 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 5bc6333531a51..04dfc629ffda3 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -20,10 +20,10 @@ from .fp16_lists import * from . import fp16_utils from .fp16_utils import * -from . import bf16 +# from . 
import bf16 from .bf16 import * __all__ = decorator.__all__ __all__ += fp16_lists.__all__ __all__ += fp16_utils.__all__ -__all__ += ['bf16'] +__all__ += bf16.__all__ diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index 8cfe7f9accd20..b26588f1376d2 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -14,10 +14,9 @@ import copy import unittest import paddle.fluid as fluid +import paddle.fluid.contrib.mixed_precision as amp from paddle.fluid import core -from paddle.fluid.contrib.mixed_precision.bf16 import amp_utils -from paddle.fluid.contrib.mixed_precision.bf16 import amp_lists -from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionListsBF16 +from paddle.fluid.contrib.mixed_precision import AutoMixedPrecisionListsBF16 import paddle paddle.enable_static() @@ -25,95 +24,95 @@ class AMPTest(unittest.TestCase): def test_amp_lists(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) - amp_lists_ = AutoMixedPrecisionListsBF16() + amp_lists_ = amp.AutoMixedPrecisionListsBF16() self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_1(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 1. w={'exp}, b=None bf16_list.add('exp') fp32_list.remove('exp') - amp_lists_ = AutoMixedPrecisionListsBF16({'exp'}) + amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_2(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 2. w={'tanh'}, b=None fp32_list.remove('tanh') bf16_list.add('tanh') - amp_lists_ = AutoMixedPrecisionListsBF16({'tanh'}) + amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_3(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 3. 
w={'lstm'}, b=None bf16_list.add('lstm') - amp_lists_ = AutoMixedPrecisionListsBF16({'lstm'}) + amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_4(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 4. w=None, b={'elementwise_add'} bf16_list.remove('elementwise_add') fp32_list.add('elementwise_add') - amp_lists_ = AutoMixedPrecisionListsBF16( + amp_lists_ = amp.AutoMixedPrecisionListsBF16( custom_fp32_list={'elementwise_add'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_5(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 5. w=None, b={'elementwise_add'} fp32_list.add('elementwise_add') bf16_list.remove('elementwise_add') - amp_lists_ = AutoMixedPrecisionListsBF16( + amp_lists_ = amp.AutoMixedPrecisionListsBF16( custom_fp32_list={'elementwise_add'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) def test_amp_lists_6(self): - bf16_list = copy.copy(amp_lists.bf16_list) - fp32_list = copy.copy(amp_lists.fp32_list) - gray_list = copy.copy(amp_lists.gray_list) + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) # 6. w=None, b={'lstm'} fp32_list.add('lstm') - amp_lists_ = AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) + amp_lists_ = amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) self.assertEqual(amp_lists_.bf16_list, bf16_list) self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) @@ -121,13 +120,13 @@ def test_amp_lists_6(self): def test_amp_lists_7(self): # 7. 
w={'lstm'} b={'lstm'} # raise ValueError - self.assertRaises(ValueError, AutoMixedPrecisionListsBF16, {'lstm'}, - {'lstm'}) + self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16, + {'lstm'}, {'lstm'}) def test_find_op_index(self): block = fluid.default_main_program().global_block() op_desc = core.OpDesc() - idx = amp_utils.find_op_index(block.desc, op_desc) + idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc) assert (idx == -1) def test_find_true_post_op(self): @@ -140,7 +139,7 @@ def test_find_true_post_op(self): type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]}) op2 = block.append_op( type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]}) - res = amp_utils.find_true_post_op(block.ops, op1, "Y") + res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y") assert (res == [op2]) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 55244b09ade6d..8cc48608228aa 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -20,8 +20,8 @@ import unittest import numpy as np import paddle.fluid.layers as layers +import paddle.fluid.contrib.mixed_precision as amp from paddle.fluid import core -from paddle.fluid.contrib.mixed_precision.bf16 import rewrite_program_bf16, convert_float_to_uint16 paddle.enable_static() @@ -57,7 +57,7 @@ def get_static_graph_result(self, feed, fetch_list, with_lod=False): exe = fluid.Executor(core.CPUPlace()) exe.run(fluid.default_startup_program()) prog = fluid.default_main_program() - rewrite_program_bf16(prog, use_bf16_guard=True) + amp.rewrite_program_bf16(prog, use_bf16_guard=True) return exe.run(prog, feed=feed, fetch_list=fetch_list, @@ -68,8 +68,8 @@ def test_elementwise_math(self): n = np.ones([size, size], dtype='float32') * 3.2 nn = np.ones([size, size], dtype='float32') * -2.7 - n_bf16 = convert_float_to_uint16(n) - nn_bf16 = convert_float_to_uint16(nn) + n_bf16 = amp.convert_float_to_uint16(n) + nn_bf16 = amp.convert_float_to_uint16(nn) with self.static_graph(): t_bf16 = layers.data( diff --git a/python/paddle/static/amp/__init__.py b/python/paddle/static/amp/__init__.py index 604c7c3d2b490..4a634490f34c6 100644 --- a/python/paddle/static/amp/__init__.py +++ b/python/paddle/static/amp/__init__.py @@ -14,5 +14,7 @@ from ...fluid.contrib import mixed_precision from ...fluid.contrib.mixed_precision import * +from ...fluid.contrib.mixed_precision import bf16 __all__ = mixed_precision.__all__ +__all__ += mixed_precision.bf16.__all__ diff --git a/python/setup.py.in b/python/setup.py.in index 0afc3956a01e1..0fc95ec5c5322 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -179,6 +179,7 @@ packages=['paddle', 'paddle.fluid.contrib.utils', 'paddle.fluid.contrib.extend_optimizer', 'paddle.fluid.contrib.mixed_precision', + 'paddle.fluid.contrib.mixed_precision.bf16', 'paddle.fluid.contrib.layers', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details', From 70987249f6da265a44f62e632bfd6a4f92684e36 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 10 Mar 2021 17:26:18 +0100 Subject: [PATCH 17/33] More tests --- .../contrib/mixed_precision/bf16/amp_lists.py | 6 ++ .../paddle/fluid/contrib/tests/CMakeLists.txt | 3 + .../contrib/tests/test_fit_a_line_bf16.py | 96 +++++++++++++++++++ .../contrib/tests/test_model_cast_to_bf16.py | 6 +- python/paddle/fluid/data_feeder.py | 4 +- 5 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 
python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 6a4840f272ce5..15c2292e3cdc1 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -28,6 +28,12 @@ class AutoMixedPrecisionListsBF16(object): custom_bf16_list (set): Users' custom bf16 list. custom_fp32_list (set): Users' custom fp32 list. custom_fp32_varnames (set): Users' custom fp32 variables' names. + + Examples: + .. code-block:: python + + with paddle.static.amp.bf16_guard(): + AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) """ def __init__(self, diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 779cf33b6b8b9..6b76e51ce37e3 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -2,6 +2,7 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) +list(REMOVE_ITEM TEST_OPS test_fit_a_line_bf16) list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) foreach(src ${TEST_OPS}) @@ -15,6 +16,8 @@ set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) if(WITH_MKLDNN) + py_test_modules(test_fit_a_line_bf16 MODULES test_fit_a_line_bf16) py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) + set_tests_properties(test_fit_a_line_bf16 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py new file mode 100644 index 0000000000000..85e15d494190d --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py @@ -0,0 +1,96 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import paddle.static.amp as amp +import contextlib +import unittest +import math +import sys + +paddle.enable_static() + + +def train(): + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + + with paddle.static.amp.bf16_guard(): + y_predict = fluid.layers.fc(input=x, size=1, act=None) + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + amp.rewrite_program_bf16(fluid.default_main_program(), use_bf16_guard=True) + sgd_optimizer.minimize(avg_cost) + + BATCH_SIZE = 20 + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + def train_loop(main_program): + feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) + exe.run(fluid.default_startup_program()) + + PASS_NUM = 100 + for pass_id in range(PASS_NUM): + for data in train_reader(): + avg_loss_value, = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[avg_cost]) + print(avg_loss_value) + if avg_loss_value[0] < 10.0: + return + if math.isnan(float(avg_loss_value)): + sys.exit("got NaN loss, training failed.") + raise AssertionError("Fit a line cost is too large, {0:2.2}".format( + avg_loss_value[0])) + + train_loop(fluid.default_main_program()) + + +def main(): + if not fluid.core.is_compiled_with_mkldnn(): + return + + train() + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + main() + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 8cc48608228aa..51b0be2621e63 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -20,7 +20,7 @@ import unittest import numpy as np import paddle.fluid.layers as layers -import paddle.fluid.contrib.mixed_precision as amp +import paddle.static.amp as amp from paddle.fluid import core paddle.enable_static() @@ -83,12 +83,12 @@ def test_elementwise_math(self): ret = layers.elementwise_mul(ret, t) ret = fluid.layers.reshape(ret, [0, 0]) - with paddle.static.amp.bf16.bf16_guard(): + with amp.bf16_guard(): ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16) ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16) ret_bf16 = layers.reshape(ret_bf16, [0, 0]) - with paddle.static.amp.bf16.bf16_guard(): + with amp.bf16_guard(): ret_fp32bf16 = layers.elementwise_add(t, tt) ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t) ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0]) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index f693e250a4a14..6cc3f09fcb69b 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -126,7 +126,9 @@ def check_dtype(input_dtype, warnings.warn( "The data type of '%s' in %s only support float16 in GPU now. 
%s" % (input_name, op_name, extra_message)) - if convert_dtype(input_dtype) in ['uint16']: + if convert_dtype(input_dtype) in ['uint16'] and op_name not in [ + 'reshape', 'lookup_table' + ]: warnings.warn( "The data type of '%s' in %s only support bfloat16 in OneDNN now. %s" % (input_name, op_name, extra_message)) From ed6cd06ed4cf8d3ec55240f4cf1c9c835ec30be1 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 10 Mar 2021 19:40:27 +0100 Subject: [PATCH 18/33] More tests, introduced bf16 scale op --- paddle/fluid/operators/scale_op.cc | 2 + .../fluid/contrib/mixed_precision/__init__.py | 2 +- .../contrib/mixed_precision/bf16/amp_lists.py | 5 +- .../paddle/fluid/contrib/tests/CMakeLists.txt | 3 - .../contrib/tests/test_fit_a_line_bf16.py | 96 ------------------- python/paddle/fluid/data_feeder.py | 2 +- python/paddle/fluid/dygraph/amp/auto_cast.py | 4 +- python/paddle/fluid/layers/nn.py | 4 +- .../fluid/tests/book/test_fit_a_line.py | 15 ++- .../fluid/tests/book/test_word2vec_book.py | 29 ++++-- 10 files changed, 43 insertions(+), 119 deletions(-) delete mode 100644 python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 281689d3bdaff..a9b1f299dab82 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, REGISTER_OP_CPU_KERNEL( scale, ops::ScaleKernel, ops::ScaleKernel, + ops::ScaleKernel, ops::ScaleKernel, ops::ScaleKernel, ops::ScaleKernel, diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 04dfc629ffda3..571b755b50d2a 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -20,7 +20,7 @@ from .fp16_lists import * from . import fp16_utils from .fp16_utils import * -# from . import bf16 +from . import bf16 from .bf16 import * __all__ = decorator.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 15c2292e3cdc1..81e6307da7430 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -31,9 +31,10 @@ class AutoMixedPrecisionListsBF16(object): Examples: .. 
code-block:: python - + import paddle + paddle.enable_static() with paddle.static.amp.bf16_guard(): - AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) + paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'}) """ def __init__(self, diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 6b76e51ce37e3..779cf33b6b8b9 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -2,7 +2,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) -list(REMOVE_ITEM TEST_OPS test_fit_a_line_bf16) list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) foreach(src ${TEST_OPS}) @@ -16,8 +15,6 @@ set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) if(WITH_MKLDNN) - py_test_modules(test_fit_a_line_bf16 MODULES test_fit_a_line_bf16) py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) - set_tests_properties(test_fit_a_line_bf16 PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py b/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py deleted file mode 100644 index 85e15d494190d..0000000000000 --- a/python/paddle/fluid/contrib/tests/test_fit_a_line_bf16.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import print_function - -import paddle -import paddle.fluid as fluid -import paddle.static.amp as amp -import contextlib -import unittest -import math -import sys - -paddle.enable_static() - - -def train(): - x = fluid.layers.data(name='x', shape=[13], dtype='float32') - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - - with paddle.static.amp.bf16_guard(): - y_predict = fluid.layers.fc(input=x, size=1, act=None) - - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - amp.rewrite_program_bf16(fluid.default_main_program(), use_bf16_guard=True) - sgd_optimizer.minimize(avg_cost) - - BATCH_SIZE = 20 - - train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=BATCH_SIZE) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - def train_loop(main_program): - feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) - exe.run(fluid.default_startup_program()) - - PASS_NUM = 100 - for pass_id in range(PASS_NUM): - for data in train_reader(): - avg_loss_value, = exe.run(main_program, - feed=feeder.feed(data), - fetch_list=[avg_cost]) - print(avg_loss_value) - if avg_loss_value[0] < 10.0: - return - if math.isnan(float(avg_loss_value)): - sys.exit("got NaN loss, training failed.") - raise AssertionError("Fit a line cost is too large, {0:2.2}".format( - avg_loss_value[0])) - - train_loop(fluid.default_main_program()) - - -def main(): - if not fluid.core.is_compiled_with_mkldnn(): - return - - train() - - -class TestFitALine(unittest.TestCase): - def test_cpu(self): - with self.program_scope_guard(): - main() - - @contextlib.contextmanager - def program_scope_guard(self): - prog = fluid.Program() - startup_prog = fluid.Program() - scope = fluid.core.Scope() - with fluid.scope_guard(scope): - with fluid.program_guard(prog, startup_prog): - yield - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 6cc3f09fcb69b..52be7493cf229 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -127,7 +127,7 @@ def check_dtype(input_dtype, "The data type of '%s' in %s only support float16 in GPU now. %s" % (input_name, op_name, extra_message)) if convert_dtype(input_dtype) in ['uint16'] and op_name not in [ - 'reshape', 'lookup_table' + 'reshape', 'lookup_table', 'scale' ]: warnings.warn( "The data type of '%s' in %s only support bfloat16 in OneDNN now. %s" diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 2207bdfc83505..4ff08337875c0 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -60,8 +60,8 @@ } -#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionListsBF16._update_list -# The reason why not use AutoMixedPrecisionListsBF16 is that custom_black_varnames is not suitable for imperative mode. +#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list +# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. def _update_list(custom_white_list, custom_black_list): """ Update black and white list according to users' custom list. 
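For reference, the static-graph flow exercised by the fit-a-line and word2vec tests updated below looks roughly like the following. This is an illustrative sketch only, not part of the patch; it assumes a CPU build of Paddle with oneDNN (MKL-DNN) support and uses the bf16 helpers exported through paddle.static.amp earlier in this series:

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.static.amp as amp

    paddle.enable_static()

    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # Ops built under bf16_guard() are the ones the rewrite pass may keep
    # in bf16 when use_bf16_guard=True is passed.
    with amp.bf16_guard():
        y_predict = fluid.layers.fc(input=x, size=1, act=None)

    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    # Rewrite the forward program in place: bf16 ops get
    # mkldnn_data_type='bfloat16' and cast ops are inserted at fp32/bf16
    # boundaries. An AutoMixedPrecisionListsBF16 instance may be passed as
    # the second argument to customize the op lists.
    amp.rewrite_program_bf16(fluid.default_main_program(), use_bf16_guard=True)

    sgd = fluid.optimizer.SGD(learning_rate=0.001)
    sgd.minimize(avg_cost)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    loss, = exe.run(fluid.default_main_program(),
                    feed={'x': np.random.rand(20, 13).astype('float32'),
                          'y': np.random.rand(20, 1).astype('float32')},
                    fetch_list=[avg_cost])

The same rewrite_program_bf16 call is what test_model_cast_to_bf16.py drives with a custom AutoMixedPrecisionListsBF16(custom_fp32_varnames={...}) to keep selected variables in fp32.
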
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 998582dac29e3..00d1db19fc2f5 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11430,8 +11430,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): return dygraph_utils._append_activation_in_dygraph(out) check_variable_and_dtype(x, "x", [ - 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8' + 'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32', + 'int64', 'uint8' ], "scale") inputs = {'X': [x]} attrs = { diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 9a2cc4ab1a1b9..926366557637f 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -26,7 +26,7 @@ paddle.enable_static() -def train(use_cuda, save_dirname, is_local): +def train(use_cuda, save_dirname, is_local, use_bf16): x = fluid.layers.data(name='x', shape=[13], dtype='float32') y_predict = fluid.layers.fc(input=x, size=1, act=None) @@ -37,6 +37,8 @@ def train(use_cuda, save_dirname, is_local): avg_cost = fluid.layers.mean(cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + if use_bf16: + paddle.static.amp.rewrite_program_bf16(fluid.default_main_program()) sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 20 @@ -133,14 +135,17 @@ def infer(use_cuda, save_dirname=None): print("ground truth: ", test_label) -def main(use_cuda, is_local=True): +def main(use_cuda, is_local=True, use_bf16=False): if use_cuda and not fluid.core.is_compiled_with_cuda(): return + if use_bf16 and not fluid.core.is_compiled_with_mkldnn(): + return + # Directory for saving the trained model save_dirname = "fit_a_line.inference.model" - train(use_cuda, save_dirname, is_local) + train(use_cuda, save_dirname, is_local, use_bf16) infer(use_cuda, save_dirname) @@ -153,6 +158,10 @@ def test_cuda(self): with self.program_scope_guard(): main(use_cuda=True) + def test_bf16(self): + with self.program_scope_guard(): + main(use_cuda=False, use_bf16=True) + @contextlib.contextmanager def program_scope_guard(self): prog = fluid.Program() diff --git a/python/paddle/fluid/tests/book/test_word2vec_book.py b/python/paddle/fluid/tests/book/test_word2vec_book.py index e33b1cc514aa6..ad7550fa9dd96 100644 --- a/python/paddle/fluid/tests/book/test_word2vec_book.py +++ b/python/paddle/fluid/tests/book/test_word2vec_book.py @@ -39,7 +39,12 @@ def get_place(target): format(target)) -def train(target, is_sparse, is_parallel, save_dirname, is_local=True): +def train(target, + is_sparse, + is_parallel, + save_dirname, + is_local=True, + use_bf16=False): PASS_NUM = 100 EMBED_SIZE = 32 HIDDEN_SIZE = 256 @@ -101,6 +106,8 @@ def __network__(words): raise NotImplementedError() sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + if use_bf16: + paddle.static.amp.rewrite_program_bf16(fluid.default_main_program()) sgd_optimizer.minimize(avg_cost) train_reader = paddle.batch( @@ -239,12 +246,15 @@ def to_infer_tensor(lod_tensor): assert np.isclose(a, b, rtol=5e-5), "a: {}, b: {}".format(a, b) -def main(target, is_sparse, is_parallel): +def main(target, is_sparse, is_parallel, use_bf16): if target == "cuda" and not fluid.core.is_compiled_with_cuda(): return if target == "xpu" and not fluid.core.is_compiled_with_xpu(): return + if use_bf16 and not fluid.core.is_compiled_with_mkldnn(): + return + if not is_parallel: save_dirname = "word2vec.inference.model" 
else: @@ -255,7 +265,7 @@ def main(target, is_sparse, is_parallel): # so only inference is turned on. train("cpu", is_sparse, is_parallel, save_dirname) else: - train(target, is_sparse, is_parallel, save_dirname) + train(target, is_sparse, is_parallel, save_dirname, use_bf16=use_bf16) infer(target, save_dirname) @@ -268,10 +278,11 @@ class W2VTest(unittest.TestCase): pass -def inject_test_method(target, is_sparse, is_parallel): - fn_name = "test_{0}_{1}_{2}".format(target, "sparse" - if is_sparse else "dense", "parallel" - if is_parallel else "normal") +def inject_test_method(target, is_sparse, is_parallel, use_bf16=False): + fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse" + if is_sparse else "dense", "parallel" + if is_parallel else "normal", "_bf16" + if use_bf16 else "") def __impl__(*args, **kwargs): prog = fluid.Program() @@ -279,8 +290,7 @@ def __impl__(*args, **kwargs): scope = fluid.core.Scope() with fluid.scope_guard(scope): with fluid.program_guard(prog, startup_prog): - main( - target=target, is_sparse=is_sparse, is_parallel=is_parallel) + main(target, is_sparse, is_parallel, use_bf16) if (not fluid.core.is_compiled_with_cuda() or target == "cuda") and is_sparse: @@ -297,6 +307,7 @@ def __impl__(*args, **kwargs): for is_sparse in (False, True): for is_parallel in (False, ): inject_test_method(target, is_sparse, is_parallel) +inject_test_method("cpu", False, False, use_bf16=True) if __name__ == '__main__': unittest.main() From 8d65d4445f3e8eb07962a44c77b1c951fbfed30a Mon Sep 17 00:00:00 2001 From: arlesniak Date: Thu, 11 Mar 2021 15:17:44 +0100 Subject: [PATCH 19/33] Changes for CI --- .../contrib/mixed_precision/bf16/amp_utils.py | 77 +------------------ .../paddle/fluid/contrib/tests/CMakeLists.txt | 6 +- 2 files changed, 3 insertions(+), 80 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py index f64312240e9e0..d4873c407a170 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -21,6 +21,7 @@ from ....log_helper import get_logger from ....wrapped_decorator import signature_safe_contextmanager from .amp_lists import AutoMixedPrecisionListsBF16 +from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, find_op_index import logging import numpy as np @@ -45,23 +46,6 @@ def convert_float_to_uint16(in_list): return np.reshape(out, in_list.shape) -def _rename_arg(op, old_name, new_name): - """ - If an op has old_name input and output, rename these input - args new_name. - - Args: - op (Operator): Current operator. - old_name (str): The old name of input args. - new_name (str): The new name of input args. - """ - op_desc = op.desc - if isinstance(op_desc, tuple): - op_desc = op_desc[0] - op_desc._rename_input(old_name, new_name) - op_desc._rename_output(old_name, new_name) - - def _dtype_to_str(dtype): """ Convert specific variable type to its corresponding string. @@ -142,65 +126,6 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): return num_cast_ops -def find_true_prev_op(ops, cur_op, var_name): - """ - Find the true prev op that outputs var_name variable. - - Args: - ops (list): A list of ops. - cur_op (Operator): Current operator which has var_name variable. - var_name (string): Variable name. 
- """ - prev_op = [] - for op in ops: - if op == cur_op: - break - for out_name in op.output_names: - for out_var_name in op.output(out_name): - if out_var_name == var_name: - prev_op.append(op) - if prev_op: - if not len(prev_op) == 1: - raise ValueError("There must be only one previous op " - "that outputs {0} variable".format(var_name)) - else: - return prev_op[0] - return None - - -def find_true_post_op(ops, cur_op, var_name): - """ - if there are post ops, return them, if there is no post op, - return None instead. - Args: - ops (list): A list of ops. - cur_op (Operator): Current operator which has var_name variable. - var_name (string): Variable name. - """ - post_op = [] - for idx, op in enumerate(ops): - if op == cur_op: - break - - for i in range(idx + 1, len(ops)): - op = ops[i] - for in_name in op.input_names: - for in_var_name in op.input(in_name): - if in_var_name == var_name: - post_op.append(op) - - return post_op - - -def find_op_index(block_desc, cur_op_desc): - """ - """ - for idx in range(block_desc.op_size()): - if cur_op_desc == block_desc.op(idx): - return idx - return -1 - - def _is_in_fp32_varnames(op, amp_lists): for in_name in op.input_arg_names: if in_name in amp_lists.fp32_varnames: diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 779cf33b6b8b9..4f7b9f2df0da6 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -14,7 +14,5 @@ set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) -if(WITH_MKLDNN) - py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) - set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) -endif() +py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) +set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) From f4f958b54a0d4997f6f72b6fc545ec9efe51d76a Mon Sep 17 00:00:00 2001 From: arlesniak Date: Thu, 11 Mar 2021 20:38:48 +0100 Subject: [PATCH 20/33] Changes for CI --- .../contrib/mixed_precision/bf16/amp_utils.py | 5 --- .../contrib/tests/test_model_cast_to_bf16.py | 42 +++++++++++++++++-- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py index d4873c407a170..c2c01f88c7431 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -280,11 +280,6 @@ def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False): core.VarDesc.VarType.BF16, core.VarDesc.VarType.FP32) elif op in bf16_op_set: - if use_bf16_guard: - if not (op.has_attr('op_namescope') and - (_bf16_guard_pattern in op.attr("op_namescope"))): - idx += 1 - continue if op.has_attr('use_mkldnn'): op._set_attr('use_mkldnn', True) op._set_attr('mkldnn_data_type', 'bfloat16') diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 51b0be2621e63..280173f9a164c 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -53,11 +53,13 @@ def scope_prog_guard(self): with fluid.program_guard(prog, startup_prog): yield - def 
get_static_graph_result(self, feed, fetch_list, with_lod=False): + def get_static_graph_result(self, feed, fetch_list, amp_fun, + with_lod=False): exe = fluid.Executor(core.CPUPlace()) exe.run(fluid.default_startup_program()) prog = fluid.default_main_program() - amp.rewrite_program_bf16(prog, use_bf16_guard=True) + if amp_fun is not None: + amp_fun(prog) return exe.run(prog, feed=feed, fetch_list=fetch_list, @@ -81,7 +83,7 @@ def test_elementwise_math(self): ret = layers.elementwise_add(t, tt) ret = layers.elementwise_mul(ret, t) - ret = fluid.layers.reshape(ret, [0, 0]) + ret = layers.reshape(ret, [0, 0]) with amp.bf16_guard(): ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16) @@ -100,11 +102,43 @@ def test_elementwise_math(self): 't_bf16': n_bf16, 'tt_bf16': nn_bf16, }, - fetch_list=[ret_bf16, ret, ret_fp32bf16]) + fetch_list=[ret_bf16, ret, ret_fp32bf16], + amp_fun=lambda prog: amp.rewrite_program_bf16(prog, use_bf16_guard=True)) self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2)) self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2)) + def test_op_rewrite(self): + size = 3 + n = np.ones([size, size], dtype='float32') * 3.2 + nn = np.ones([size, size], dtype='float32') * -2.7 + + with self.static_graph(): + t = layers.data(name='t', shape=[size, size], dtype='float32') + tt = layers.data(name='tt', shape=[size, size], dtype='float32') + + with amp.bf16_guard(): + ret = layers.elementwise_add(t, tt) + ret = layers.reshape(ret, [0, 0], act='elu') + ret = layers.elementwise_mul(ret, t) + ret = layers.elementwise_add(ret, tt) + + static_ret_bf16 = \ + self.get_static_graph_result( + feed={'t': n, 'tt': nn}, + fetch_list=[ret], + amp_fun=lambda prog: amp.rewrite_program_bf16( + prog, + amp.AutoMixedPrecisionListsBF16( + custom_fp32_varnames={'elementwise_mul'}, + ), + use_bf16_guard=True + ) + ) + self.assertTrue( + static_ret_bf16, np.ones( + [size, size], dtype='float32') * -1.1) + if __name__ == '__main__': unittest.main() From 6117982d826e61f03c94373b7f85f76cfcc54492 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Fri, 12 Mar 2021 12:14:33 +0100 Subject: [PATCH 21/33] Changes for CI --- python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 280173f9a164c..bc24e1851a589 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -130,7 +130,7 @@ def test_op_rewrite(self): amp_fun=lambda prog: amp.rewrite_program_bf16( prog, amp.AutoMixedPrecisionListsBF16( - custom_fp32_varnames={'elementwise_mul'}, + custom_fp32_varnames={'elementwise_mul_0.tmp_0'}, ), use_bf16_guard=True ) From 72405cf98962104566be78a278484aafc9eb777a Mon Sep 17 00:00:00 2001 From: arlesniak Date: Fri, 12 Mar 2021 14:56:47 +0100 Subject: [PATCH 22/33] Changes for CI --- python/paddle/static/amp/__init__.py | 3 ++- tools/static_mode_white_list.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/static/amp/__init__.py b/python/paddle/static/amp/__init__.py index 4a634490f34c6..bfc1beed55297 100644 --- a/python/paddle/static/amp/__init__.py +++ b/python/paddle/static/amp/__init__.py @@ -15,6 +15,7 @@ from ...fluid.contrib import mixed_precision from ...fluid.contrib.mixed_precision import * from ...fluid.contrib.mixed_precision import bf16 +from ...fluid.contrib.mixed_precision.bf16 
import * __all__ = mixed_precision.__all__ -__all__ += mixed_precision.bf16.__all__ +__all__ += bf16.__all__ diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 2ea3f7654afda..6453eb48d7004 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -699,4 +699,5 @@ 'test_slice_op_xpu', 'test_generate_proposals_v2_op', 'test_lamb_op_xpu', + 'test_model_cast_to_bf16', ] From c213d0833d344e78030d35748c559e317eb204ff Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 17 Mar 2021 09:28:45 +0100 Subject: [PATCH 23/33] Changes for CI --- .../fluid/contrib/tests/test_bf16_utils.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index b26588f1376d2..8c69f9bd06861 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -123,6 +123,34 @@ def test_amp_lists_7(self): self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16, {'lstm'}, {'lstm'}) + def test_amp_lists_8(self): + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) + + fp32_list.add('reshape2') + gray_list.remove('reshape2') + + amp_lists_ = amp.AutoMixedPrecisionListsBF16( + custom_fp32_list={'reshape2'}) + self.assertEqual(amp_lists_.bf16_list, bf16_list) + self.assertEqual(amp_lists_.fp32_list, fp32_list) + self.assertEqual(amp_lists_.gray_list, gray_list) + + def test_amp_list9_8(self): + bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + gray_list = copy.copy(amp.bf16.amp_lists.gray_list) + + bf16_list.add('reshape2') + gray_list.remove('reshape2') + + amp_lists_ = amp.AutoMixedPrecisionListsBF16( + custom_bf16_list={'reshape2'}) + self.assertEqual(amp_lists_.bf16_list, bf16_list) + self.assertEqual(amp_lists_.fp32_list, fp32_list) + self.assertEqual(amp_lists_.gray_list, gray_list) + def test_find_op_index(self): block = fluid.default_main_program().global_block() op_desc = core.OpDesc() From b6c4ad2b0c7b8b990fcca892137c4bb852e1a094 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 17 Mar 2021 14:09:17 +0100 Subject: [PATCH 24/33] Changes for CI --- python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index bc24e1851a589..850ff431863c4 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -130,8 +130,7 @@ def test_op_rewrite(self): amp_fun=lambda prog: amp.rewrite_program_bf16( prog, amp.AutoMixedPrecisionListsBF16( - custom_fp32_varnames={'elementwise_mul_0.tmp_0'}, - ), + custom_fp32_varnames={'elementwise_add_0.tmp_0'}), use_bf16_guard=True ) ) From 0bda415f6fa82f73e253d74a8a25bad70f4a2af3 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Wed, 17 Mar 2021 19:10:31 +0100 Subject: [PATCH 25/33] Changes for CI --- .../paddle/fluid/contrib/tests/test_bf16_utils.py | 2 +- .../fluid/contrib/tests/test_model_cast_to_bf16.py | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index 
8c69f9bd06861..afe9ccf3f22a6 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -137,7 +137,7 @@ def test_amp_lists_8(self): self.assertEqual(amp_lists_.fp32_list, fp32_list) self.assertEqual(amp_lists_.gray_list, gray_list) - def test_amp_list9_8(self): + def test_amp_list_9(self): bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) gray_list = copy.copy(amp.bf16.amp_lists.gray_list) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 850ff431863c4..40ddcf2e66b75 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -65,7 +65,7 @@ def get_static_graph_result(self, feed, fetch_list, amp_fun, fetch_list=fetch_list, return_numpy=(not with_lod)) - def test_elementwise_math(self): + def test_graph_rewrite(self): size = 3 n = np.ones([size, size], dtype='float32') * 3.2 nn = np.ones([size, size], dtype='float32') * -2.7 @@ -108,11 +108,6 @@ def test_elementwise_math(self): self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2)) self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2)) - def test_op_rewrite(self): - size = 3 - n = np.ones([size, size], dtype='float32') * 3.2 - nn = np.ones([size, size], dtype='float32') * -2.7 - with self.static_graph(): t = layers.data(name='t', shape=[size, size], dtype='float32') tt = layers.data(name='tt', shape=[size, size], dtype='float32') @@ -134,9 +129,9 @@ def test_op_rewrite(self): use_bf16_guard=True ) ) - self.assertTrue( - static_ret_bf16, np.ones( - [size, size], dtype='float32') * -1.1) + self.assertTrue( + static_ret_bf16, np.ones( + [size, size], dtype='float32') * -1.1) if __name__ == '__main__': From 68ea1f9f6dde9cd0fe04b1b1b89d76bfe98b31ba Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Wed, 17 Mar 2021 22:28:47 +0100 Subject: [PATCH 26/33] Changes for CI --- python/paddle/fluid/contrib/tests/CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index 4f7b9f2df0da6..a28588bfa5382 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -2,7 +2,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) -list(REMOVE_ITEM TEST_OPS test_model_cast_to_bf16) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) @@ -13,6 +12,3 @@ py_test_modules(test_multi_precision_fp16_train MODULES test_multi_precision_fp1 set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) - -py_test_modules(test_model_cast_to_bf16 MODULES test_model_cast_to_bf16) -set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 120) From 09a2f473ee1009630042a9c2395a615e3ec7081b Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Thu, 18 Mar 2021 06:20:13 +0100 Subject: [PATCH 27/33] Changes for CI --- python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py 
b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index 40ddcf2e66b75..d9088a8ebe74d 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -26,8 +26,6 @@ paddle.enable_static() -@unittest.skipIf(not core.supports_bfloat16(), - "place does not support BF16 evaluation") class TestModelCastBF16(unittest.TestCase): @classmethod def setUpClass(cls): From 274fa0bd631f5a9dfa6bafa879e4bf44a7e73418 Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Thu, 18 Mar 2021 07:36:56 +0100 Subject: [PATCH 28/33] Changes for CI --- .../fluid/contrib/tests/test_bf16_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index afe9ccf3f22a6..a1eef560d88b7 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -157,6 +157,24 @@ def test_find_op_index(self): idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc) assert (idx == -1) + def test_is_in_fp32_varnames(self): + block = fluid.default_main_program().global_block() + + var1 = block.create_var(name="X", shape=[3], dtype='float32') + var2 = block.create_var(name="Y", shape=[3], dtype='float32') + var3 = block.create_var(name="Z", shape=[3], dtype='float32') + op1 = block.append_op( + type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]}) + op2 = block.append_op( + type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]}) + amp_lists_1 = amp.AutoMixedPrecisionListsBF16( + custom_fp32_varnames={'X'}) + assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1) + amp_lists_2 = amp.AutoMixedPrecisionListsBF16( + custom_fp32_varnames={'Y'}) + assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2) + assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2) + def test_find_true_post_op(self): block = fluid.default_main_program().global_block() From 11dc27896304aa4a7de1cfd75735aaaea0bef778 Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Thu, 18 Mar 2021 09:03:15 +0100 Subject: [PATCH 29/33] Changes for CI --- python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index d9088a8ebe74d..40ddcf2e66b75 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -26,6 +26,8 @@ paddle.enable_static() +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") class TestModelCastBF16(unittest.TestCase): @classmethod def setUpClass(cls): From bd2dea83e7e12b58eee1c4b3c65bf6e469136498 Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Thu, 18 Mar 2021 12:18:08 +0100 Subject: [PATCH 30/33] Changes to trigger blocked CIs --- python/paddle/fluid/contrib/tests/test_bf16_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index a1eef560d88b7..1f65aef540839 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -176,6 +176,7 @@ def test_is_in_fp32_varnames(self): assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2) def test_find_true_post_op(self): + block = 
fluid.default_main_program().global_block() var1 = block.create_var(name="X", shape=[3], dtype='float32') From f07ca1507d13a3e09422c22ca8020db58ab45587 Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Thu, 18 Mar 2021 14:24:00 +0100 Subject: [PATCH 31/33] Changes for CI --- python/paddle/fluid/tests/book/test_fit_a_line.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 926366557637f..df43d9366ff78 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -158,6 +158,8 @@ def test_cuda(self): with self.program_scope_guard(): main(use_cuda=True) + @unittest.skipIf(not fluid.core.supports_bfloat16(), + "place does not support BF16 evaluation") def test_bf16(self): with self.program_scope_guard(): main(use_cuda=False, use_bf16=True) From d8f810cc161b837464f529fac064a2093700c89f Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Fri, 19 Mar 2021 11:29:39 +0100 Subject: [PATCH 32/33] Less lines in amp_lists.py --- .../contrib/mixed_precision/bf16/amp_lists.py | 212 +----------------- 1 file changed, 12 insertions(+), 200 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 81e6307da7430..216f55657b13f 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\ + gray_list as gray_list_fp16, unsupported_fp16_list __all__ = ["AutoMixedPrecisionListsBF16"] @@ -80,207 +82,17 @@ def _update_list(self): bf16_list = {'elementwise_add', } # depends on the prev_op type -gray_list = {'reshape2', } - -# always fp32 -fp32_list = { - 'conv2d', - 'matmul', - 'matmul_v2', - 'mul', - 'exp', - 'square', - 'log', - 'mean', - 'sum', - 'cos_sim', - 'softmax', - 'softmax_with_cross_entropy', - 'sigmoid_cross_entropy_with_logits', - 'cross_entropy', - 'cross_entropy2', +gray_list = { + 'reshape2', 'lookup_table', - 'lookup_table_v2', - # 'elementwise_add', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_div', - 'elementwise_max', - 'elementwise_min', - 'elementwise_pow', - 'elementwise_mod', - 'elementwise_floordiv', - 'batch_norm', - 'layer_norm', - 'tanh', - 'sigmoid', - 'top_k', - 'pool2d', - 'pool3d', - 'dropout', - 'relu', - 'relu6', - 'leaky_relu', - 'soft_relu', - 'flatten2', - 'stack', - 'unstack', - 'uniform_random', - 'uniform_random_batch_size_like', - 'gaussian_random', - 'gaussian_random_batch_size_like', - 'slice', - 'rank', - 'scale', - 'transpose2', - # 'reshape2', - 'gather', - 'fill_constant', - 'get_tensor_from_selected_rows', - 'sign', - 'cast', - 'fused_bn_add_activation', } -# The set of ops that don't support bf16 calculation -unsupported_list = { - # from python/paddle/fluid/layers/io.py - 'send', - 'send_barrier', - 'recv', - 'fetch_barrier', - 'create_py_reader', - 'create_double_buffer_reader', - 'read', - 'load', +unsupported_list = unsupported_fp16_list.copy().copy() +fp32_list = black_list_fp16.copy().copy() +fp32_list |= white_list_fp16 +fp32_list |= gray_list_fp16 - # from python/paddle/fluid/control_flow.py - 'increment', - 'less_than', - 'less_equal', - 'greater_than', - 'greater_equal', - 'equal', - 'not_equal', - 'read_from_array', - 'shrink_rnn_memory', - 
'lod_array_length', - 'logical_and', - 'logical_or', - 'logical_xor', - 'logical_not', - 'print', - 'conditional_block', - 'while', - 'ifelse', - 'is_empty', - 'lstm', - 'cudnn_lstm', - 'lstmp', - 'gru', - 'gru_unit', - 'linear_chain_crf', - 'crf_decoding', - 'bpr_loss', - 'chunk_eval', - 'sequence_conv', - 'sequence_softmax', - # Depthwise conv2d isn't fast and safe currently. - # ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79 - 'depthwise_conv2d', - # Tensor Core kernels are not available for 3D convolutions currently. - 'conv3d', - 'sequence_pool', - 'sequence_concat', - 'sequence_slice', - 'data_norm', - 'group_norm', - 'spectral_norm', - 'depthwise_conv2d_transpose', - 'sequence_expand', - 'conv_transposed2d', - 'conv_transposed3d', - 'sequence_expand_as', - 'sequence_pad', - 'sequence_unpad', - 'sequence_erase', - 'beam_search', - 'beam_search_decode', - 'lstm_unit', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'reduce_all', - 'reduce_any', - 'split', - 'edit_distance', - 'ctc_align', - 'warpctc', - 'sequence_reshape', - 'nce', - 'hierarchical_sigmoid', - 'im2sequence', - 'row_conv', - 'multiplex', - 'sample_logits', - 'one_hot', - 'smooth_l1_loss', - 'squeeze2', - 'unsqueeze2', - 'lod_reset', - 'lrn', - 'pad', - 'pad_constant_like', - 'label_smooth', - 'scatter', - 'sequence_scatter', - 'random_crop', - 'mean_iou', - 'selu', - 'crop', - 'affine_grid', - 'rank_loss', - 'margin_rank_loss', - 'pad2d', - 'elu', - 'pow', - 'stanh', - 'hard_sigmoid', - 'swish', - 'prelu', - 'brelu', - 'sequence_enumerate', - 'sequence_mask', - 'expand', - 'sampling_id', - 'maxout', - 'space_to_depth', - 'sequence_reverse', - 'similarity_focus', - 'hash', - 'grid_sampler', - 'log_loss', - 'teacher_student_sigmoid_loss', - 'add_position_encoding', - 'bilinear_tensor_product', - 'shuffle_channel', - 'temporal_shift', - 'psroi_pool', - 'huber_loss', - 'kldiv_loss', - 'tree_conv', - 'pixel_shuffle', - 'fsp', - 'cvm', - 'affine_channel', - 'roi_pool', - 'roi_align', - 'anchor_generator', - 'generate_proposals', - 'generate_proposal_labels', - 'generate_mask_labels', - 'lookup_table', - 'lookup_table_v2', -} +fp32_list -= bf16_list +fp32_list -= gray_list +unsupported_list -= bf16_list +unsupported_list -= gray_list From 6fb0cb88a96490d68a3284e4d92ee2cd4b0c4fd8 Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Fri, 19 Mar 2021 17:17:49 +0100 Subject: [PATCH 33/33] Changes after review --- .../contrib/mixed_precision/bf16/amp_lists.py | 7 +- .../fluid/contrib/tests/test_bf16_utils.py | 138 ++++++------------ 2 files changed, 47 insertions(+), 98 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 216f55657b13f..81dc32d114b14 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -21,10 +21,9 @@ class AutoMixedPrecisionListsBF16(object): """ - AutoMixedPrecisionListsBF16 is a class for fp32/bf16 list. It can update - pre-defined fp32 list and bf16 list according to users' custom fp32 - bf16 lists. The lists are used for an algorithm which determines op's - execution mode (fp32 or bf16). + AutoMixedPrecisionListsBF16 is a class for fp32/bf16 op types list. 
The lists are used for an + algorithm which determines op's execution mode (fp32 or bf16).It can update pre-defined + fp32 list and bf16 list according to users' custom fp32 bf16 lists. Args: custom_bf16_list (set): Users' custom bf16 list. diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index 1f65aef540839..faf2307f8147b 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -16,140 +16,90 @@ import paddle.fluid as fluid import paddle.fluid.contrib.mixed_precision as amp from paddle.fluid import core -from paddle.fluid.contrib.mixed_precision import AutoMixedPrecisionListsBF16 import paddle paddle.enable_static() class AMPTest(unittest.TestCase): - def test_amp_lists(self): - bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) - fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) - gray_list = copy.copy(amp.bf16.amp_lists.gray_list) + def setUp(self): + self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) + self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) + self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list) + self.amp_lists_ = None - amp_lists_ = amp.AutoMixedPrecisionListsBF16() - self.assertEqual(amp_lists_.bf16_list, bf16_list) - self.assertEqual(amp_lists_.fp32_list, fp32_list) - self.assertEqual(amp_lists_.gray_list, gray_list) + def tearDown(self): + self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list) + self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list) + self.assertEqual(self.amp_lists_.gray_list, self.gray_list) - def test_amp_lists_1(self): - bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) - fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) - gray_list = copy.copy(amp.bf16.amp_lists.gray_list) + def test_amp_lists(self): + self.amp_lists_ = amp.AutoMixedPrecisionListsBF16() + def test_amp_lists_1(self): # 1. w={'exp}, b=None - bf16_list.add('exp') - fp32_list.remove('exp') + self.bf16_list.add('exp') + self.fp32_list.remove('exp') - amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'}) - self.assertEqual(amp_lists_.bf16_list, bf16_list) - self.assertEqual(amp_lists_.fp32_list, fp32_list) - self.assertEqual(amp_lists_.gray_list, gray_list) + self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'}) def test_amp_lists_2(self): - bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) - fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) - gray_list = copy.copy(amp.bf16.amp_lists.gray_list) - # 2. w={'tanh'}, b=None - fp32_list.remove('tanh') - bf16_list.add('tanh') + self.fp32_list.remove('tanh') + self.bf16_list.add('tanh') - amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'}) - self.assertEqual(amp_lists_.bf16_list, bf16_list) - self.assertEqual(amp_lists_.fp32_list, fp32_list) - self.assertEqual(amp_lists_.gray_list, gray_list) + self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'}) def test_amp_lists_3(self): - bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list) - fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list) - gray_list = copy.copy(amp.bf16.amp_lists.gray_list) - # 3. 
-        bf16_list.add('lstm')
+        self.bf16_list.add('lstm')

-        amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})

     def test_amp_lists_4(self):
-        bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
-        fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
-        gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
-
         # 4. w=None, b={'elementwise_add'}
-        bf16_list.remove('elementwise_add')
-        fp32_list.add('elementwise_add')
+        self.bf16_list.remove('elementwise_add')
+        self.fp32_list.add('elementwise_add')

-        amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
             custom_fp32_list={'elementwise_add'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)

     def test_amp_lists_5(self):
-        bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
-        fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
-        gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
-
         # 5. w=None, b={'elementwise_add'}
-        fp32_list.add('elementwise_add')
-        bf16_list.remove('elementwise_add')
+        self.fp32_list.add('elementwise_add')
+        self.bf16_list.remove('elementwise_add')

-        amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
             custom_fp32_list={'elementwise_add'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)

     def test_amp_lists_6(self):
-        bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
-        fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
-        gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
-
         # 6. w=None, b={'lstm'}
-        fp32_list.add('lstm')
+        self.fp32_list.add('lstm')

-        amp_lists_ = amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+            custom_fp32_list={'lstm'})

     def test_amp_lists_7(self):
-        # 7. w={'lstm'} b={'lstm'}
-        # raise ValueError
-        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
-                          {'lstm'}, {'lstm'})
-
-    def test_amp_lists_8(self):
-        bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
-        fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
-        gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
+        self.fp32_list.add('reshape2')
+        self.gray_list.remove('reshape2')

-        fp32_list.add('reshape2')
-        gray_list.remove('reshape2')
-
-        amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
             custom_fp32_list={'reshape2'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)
-
-    def test_amp_list_9(self):
-        bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
-        fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
-        gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
-        bf16_list.add('reshape2')
-        gray_list.remove('reshape2')
+    def test_amp_list_8(self):
+        self.bf16_list.add('reshape2')
+        self.gray_list.remove('reshape2')

-        amp_lists_ = amp.AutoMixedPrecisionListsBF16(
+        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
             custom_bf16_list={'reshape2'})
-        self.assertEqual(amp_lists_.bf16_list, bf16_list)
-        self.assertEqual(amp_lists_.fp32_list, fp32_list)
-        self.assertEqual(amp_lists_.gray_list, gray_list)
+
+
+class AMPTest2(unittest.TestCase):
+    def test_amp_lists_(self):
+        # 7. w={'lstm'} b={'lstm'}
+        # raise ValueError
+        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
+                          {'lstm'}, {'lstm'})

     def test_find_op_index(self):
         block = fluid.default_main_program().global_block()
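
For reference, a minimal usage sketch of the list class these tests exercise. The default placement of 'tanh' (fp32 list) and 'reshape2' (gray list) is taken from the tests above; the snippet itself is illustrative and is not code from this patch series.

    import paddle.fluid.contrib.mixed_precision as amp

    # Ask for 'tanh' to be treated as a bf16 op; by default it sits in the
    # pre-defined fp32 list (see test_amp_lists_2 above).
    amp_lists = amp.AutoMixedPrecisionListsBF16(custom_bf16_list={'tanh'})

    assert 'tanh' in amp_lists.bf16_list      # moved out of the fp32 list
    assert 'tanh' not in amp_lists.fp32_list
    assert 'reshape2' in amp_lists.gray_list  # default gray-list op, untouched

    # Forcing an op to stay in fp32 works the same way, e.g.
    #   amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'elementwise_add'})
    # while naming the same op in both custom lists raises ValueError,
    # as AMPTest2.test_amp_lists_ checks.
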