From f8f7365bdfc152e74259e43e1199c3616abf2e5d Mon Sep 17 00:00:00 2001
From: arlesniak
Date: Mon, 1 Mar 2021 16:49:11 +0100
Subject: [PATCH] Changes for CI

---
 .../contrib/mixed_precision/bf16_utils.py | 289 +++++++++---------
 1 file changed, 143 insertions(+), 146 deletions(-)

diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py
index 88d1cca5ac699..1ad21fb721d36 100644
--- a/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/bf16_utils.py
@@ -26,7 +26,10 @@
 import logging
 import numpy as np
 
-__all__ = ["bf16_guard", "convert_float_to_uint16", "convert_uint16_to_float"]
+__all__ = [
+    "bf16_guard", "cast_model_to_bf16", "convert_float_to_uint16",
+    "convert_uint16_to_float"
+]
 
 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
@@ -196,32 +199,6 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
     return num_cast_ops
 
 
-def find_true_prev_op(ops, cur_op, var_name):
-    """
-    Find the true prev op that outputs var_name variable.
-
-    Args:
-        ops (list): A list of ops.
-        cur_op (Operator): Current operator which has var_name variable.
-        var_name (string): Variable name.
-    """
-    prev_op = []
-    for op in ops:
-        if op == cur_op:
-            break
-        for out_name in op.output_names:
-            for out_var_name in op.output(out_name):
-                if out_var_name == var_name:
-                    prev_op.append(op)
-    if prev_op:
-        if not len(prev_op) == 1:
-            raise ValueError("There must be only one previous op "
-                             "that outputs {0} variable".format(var_name))
-        else:
-            return prev_op[0]
-    return None
-
-
 def find_true_post_op(ops, cur_op, var_name):
     """
     if there are post ops, return them, if there is no post op,
@@ -246,27 +223,6 @@
     return post_op
 
 
-def find_op_index(block_desc, cur_op_desc):
-    """
-    """
-    for idx in range(block_desc.op_size()):
-        if cur_op_desc == block_desc.op(idx):
-            return idx
-    return -1
-
-
-def _is_in_black_varnames(op, amp_lists):
-    for in_name in op.input_arg_names:
-        if in_name in amp_lists.black_varnames:
-            return True
-
-    for out_name in op.output_arg_names:
-        if out_name in amp_lists.black_varnames:
-            return True
-
-    return False
-
-
 def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
     if op.type in unsupported_op_list:
         # the highest priority condition: If ops don't have bf16 computing kernels,
@@ -320,108 +276,149 @@ def bf16_guard():
     yield
 
 
-def rewrite_program(main_prog, amp_lists, use_bf16_guard):
+def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
     """
-    Traverse all ops in current block and insert cast op according to
-    which set current op belongs to.
-
-    1. When an op belongs to the black list, add it to black set
-    2. When an op belongs to the white list, add it to white set
-    3. When an op belongs to the gray list. If one
-       of its inputs is the output of black set op or black list op,
-       add it to black set. If all of its previous ops are not black
-       op and one of its inputs is the output of white set op or
-       white list op, add it to white set.
-    4. When an op isn't in the lists, add it to black op set.
-    5. Add necessary cast ops to make sure that black set op will be
-       computed in fp32 mode, while white set op will be computed in
-       bf16 mode.
-
+    Traverse all ops in the whole model and set their inputs and outputs
+    to the bf16 data type. This function will do some special process for
+    the batch normalization, which keeps the computational process of
+    batchnorms in FP32.
     Args:
-        main_prog (Program): The main program for training.
+        program (Program): The used program.
+        amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object.
+        use_bf16_guard(bool): Determine whether to use `bf16_guard` when
+                              constructing the program. Default True.
     """
-    block = main_prog.global_block()
-    ops = block.ops
-    white_op_set = set()
-    black_op_set = set()
-    for op in ops:
-
-        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
-        # we don't need to handle reader op and the input of 'create_py_reader' is not
-        # in block, which may result in errors.
-        # See GeneratorLoader._init_non_iterable() for details.
-        if op.type == 'create_py_reader' or op.type == 'read':
-            continue
-
-        if amp_lists.black_varnames is not None and _is_in_black_varnames(
-                op, amp_lists):
-            black_op_set.add(op)
-            continue
-
-        if op.type in amp_lists.black_list:
-            black_op_set.add(op)
-        elif op.type in amp_lists.white_list:
-            white_op_set.add(op)
-        elif op.type in amp_lists.gray_list:
-            is_black_op = False
-            is_white_op = False
+
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionListsBF16()
+    global_block = program.global_block()
+    keep_fp32_ops = set()
+    to_bf16_var_names = set()
+    to_bf16_pre_cast_ops = set()
+    origin_ops = []
+    for block in program.blocks:
+        origin_ops.extend(block.ops)
+
+    for block in program.blocks:
+        ops = block.ops
+        for op in ops:
+            if op.type == 'create_py_reader' or op.type == 'read':
+                continue
+            if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard):
+                keep_fp32_ops.add(op)
+                continue  # processed below
             for in_name in op.input_names:
-                # if this op has inputs
-                if in_name:
-                    for in_var_name in op.input(in_name):
+                if op.type in {
+                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
+                } and in_name not in {'X', 'Z'}:
+                    continue
+                for in_var_name in op.input(in_name):
+                    in_var = None
+                    try:
                         in_var = block.var(in_var_name)
-                        # this in_var isn't the output of other op
-                        if in_var.op is None:
-                            continue
-                        elif in_var.op is op:
-                            prev_op = find_true_prev_op(ops, op, in_var_name)
-                            if prev_op is None:
-                                continue
+                    except ValueError as e:
+                        _logger.debug(
+                            "-- {}, try to get it in the global block --".
+                            format(e))
+                        in_var = global_block.var(in_var_name)
+                        if in_var is not None:
+                            _logger.debug(
+                                "-- var {} is got in the global block --".
+                                format(in_var_name))
+
+                    if in_var is None or in_var.type not in _valid_types:
+                        continue
+
+                    if in_var.dtype == core.VarDesc.VarType.FP32:
+                        if in_var.is_data:
+                            to_bf16_pre_cast_ops.add(op)
                         else:
-                            prev_op = in_var.op
-                        # if it's one of inputs
-                        if prev_op in black_op_set or \
-                                prev_op.type in amp_lists.black_list:
-                            is_black_op = True
-                        elif prev_op in white_op_set or \
-                                prev_op.type in amp_lists.white_list:
-                            is_white_op = True
-            if is_black_op:
-                black_op_set.add(op)
-            elif is_white_op:
-                white_op_set.add(op)
-            else:
-                pass
-        else:
-            # For numerical safe, we apply fp32 computation on ops that
-            # are not determined which list they should stay.
-            black_op_set.add(op)
-
-    idx = 0
-    while idx < len(ops):
-        op = ops[idx]
-        num_cast_ops = 0
-        if op in black_op_set:
-            num_cast_ops = _insert_cast_op(block, op, idx,
-                                           core.VarDesc.VarType.BF16,
-                                           core.VarDesc.VarType.FP32)
-        elif op in white_op_set:
-            if use_bf16_guard:
-                if not (op.has_attr('op_namescope') and
-                        (_bf16_guard_pattern in op.attr("op_namescope"))):
-                    idx += 1
-                    continue
-            if op.has_attr('use_mkldnn'):
-                op._set_attr('use_mkldnn', True)
-                op._set_attr('mkldnn_data_type', 'bfloat16')
-            elif op.has_attr('dtype') and op.attr(
-                    'dtype') == core.VarDesc.VarType.FP32:
-                op._set_attr('dtype', core.VarDesc.VarType.BF16)
-
-            num_cast_ops = _insert_cast_op(block, op, idx,
-                                           core.VarDesc.VarType.FP32,
-                                           core.VarDesc.VarType.BF16)
-        else:
-            pass
+                            in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+                            to_bf16_var_names.add(in_var_name)
+
+                    _logger.debug(
+                        "-- op type: {}, in var name: {}, in var dtype: {} --".
+                        format(op.type, in_var_name, in_var.dtype))
 
-        idx += num_cast_ops + 1
+            for out_name in op.output_names:
+                if op.type in {
+                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
+                } and out_name != 'Y':
+                    continue
+                for out_var_name in op.output(out_name):
+                    out_var = None
+                    try:
+                        out_var = block.var(out_var_name)
+                    except ValueError as e:
+                        _logger.debug(
+                            "-- {}, try to get it in the global block --".
+                            format(e))
+                        out_var = global_block.var(out_var_name)
+                        if out_var is not None:
+                            _logger.debug(
+                                "-- var {} is got in the global block --".
+                                format(out_var_name))
+
+                    if out_var is None or out_var.type not in _valid_types:
+                        continue
+
+                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+
+                    _logger.debug(
+                        "-- op type: {}, out var name: {}, out var dtype: {} --".
+                        format(op.type, out_var_name, out_var.dtype))
+            if op.has_attr('in_dtype') and op.attr(
+                    'in_dtype') == core.VarDesc.VarType.FP32:
+                op._set_attr('in_dtype', core.VarDesc.VarType.BF16)
+            if op.has_attr('out_dtype') and op.attr(
+                    'out_dtype') == core.VarDesc.VarType.FP32:
+                op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
+            if op.has_attr('dtype') and op.attr(
+                    'dtype') == core.VarDesc.VarType.FP32:
+                op._set_attr('dtype', core.VarDesc.VarType.BF16)
+            if op.has_attr('use_mkldnn'):
+                op._set_attr('use_mkldnn', True)
+                op._set_attr('mkldnn_data_type', 'bfloat16')
+
+    # process ops in keep_fp32_ops
+    op_var_rename_map = [
+        collections.OrderedDict() for _ in range(len(program.blocks))
+    ]
+    for block in program.blocks:
+        ops = block.ops
+        idx = 0
+        while idx < len(ops):
+            op = ops[idx]
+            num_cast_ops = 0
+            if op not in keep_fp32_ops:
+                if op in to_bf16_pre_cast_ops:
+                    in_var_cast_num = _insert_cast_op(block, op, idx,
+                                                      core.VarDesc.VarType.FP32,
+                                                      core.VarDesc.VarType.BF16)
+                    num_cast_ops += in_var_cast_num
+            else:
+                pre_cast_num = _insert_cast_op(block, op, idx,
+                                               core.VarDesc.VarType.BF16,
+                                               core.VarDesc.VarType.FP32)
+                num_cast_ops += pre_cast_num
+                for out_var_name in op.output_arg_names:
+                    out_var = block.vars.get(out_var_name)
+                    if out_var is None or out_var.type not in _valid_types:
+                        continue
+                    if out_var.dtype == core.VarDesc.VarType.BF16:
+                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
+                        post_ops = find_true_post_op(ops, op, out_var_name)
+                        for post_op in post_ops:
+                            if post_op in keep_fp32_ops:
+                                continue
+                            post_cast_num = _insert_cast_post_op(
+                                block, op, idx + pre_cast_num + 1,
+                                core.VarDesc.VarType.FP32,
+                                core.VarDesc.VarType.BF16, out_var_name,
+                                op_var_rename_map)
+                            num_cast_ops += post_cast_num
+            idx += num_cast_ops + 1
+
+    _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
+    return to_bf16_var_names
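
Editor's note (not part of the patch): a minimal usage sketch of the cast_model_to_bf16 API introduced above. The toy network, variable names, and the use_bf16_guard=False setting are illustrative assumptions; the import path simply mirrors the file touched by this patch and may differ in later Paddle releases.

    import paddle
    import paddle.fluid as fluid
    # Assumed import path: the module modified by this patch.
    from paddle.fluid.contrib.mixed_precision.bf16_utils import cast_model_to_bf16

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # Hypothetical toy network, only to give the pass something to rewrite.
        x = fluid.data(name='x', shape=[None, 16], dtype='float32')
        hidden = fluid.layers.fc(input=x, size=8)
        loss = fluid.layers.reduce_mean(hidden)

    # Rewrite the program in place: ops kept in FP32 (per _need_keep_fp32) get
    # cast ops inserted around them, everything else is switched to bf16.
    # With use_bf16_guard=False no bf16_guard scope is required in the model.
    bf16_var_names = cast_model_to_bf16(main_prog, use_bf16_guard=False)
    print(sorted(bf16_var_names))

The return value is the set of variable names whose dtype was changed to bf16, which callers can use later (for example when casting parameters before inference).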