
Commit

Changes for CI
arlesniak committed Mar 1, 2021
1 parent ff9f83a commit f8f7365
Showing 1 changed file with 143 additions and 146 deletions.
289 changes: 143 additions & 146 deletions python/paddle/fluid/contrib/mixed_precision/bf16_utils.py
@@ -26,7 +26,10 @@
import logging
import numpy as np

__all__ = ["bf16_guard", "convert_float_to_uint16", "convert_uint16_to_float"]
__all__ = [
"bf16_guard", "cast_model_to_bf16", "convert_float_to_uint16",
"convert_uint16_to_float"
]

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
@@ -196,32 +199,6 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
return num_cast_ops


def find_true_prev_op(ops, cur_op, var_name):
"""
Find the true prev op that outputs var_name variable.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
"""
prev_op = []
for op in ops:
if op == cur_op:
break
for out_name in op.output_names:
for out_var_name in op.output(out_name):
if out_var_name == var_name:
prev_op.append(op)
if prev_op:
if not len(prev_op) == 1:
raise ValueError("There must be only one previous op "
"that outputs {0} variable".format(var_name))
else:
return prev_op[0]
return None


def find_true_post_op(ops, cur_op, var_name):
"""
if there are post ops, return them; if there is no post op,
@@ -246,27 +223,6 @@ def find_true_post_op(ops, cur_op, var_name):
return post_op


def find_op_index(block_desc, cur_op_desc):
"""
"""
for idx in range(block_desc.op_size()):
if cur_op_desc == block_desc.op(idx):
return idx
return -1


def _is_in_black_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.black_varnames:
return True

for out_name in op.output_arg_names:
if out_name in amp_lists.black_varnames:
return True

return False


def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
if op.type in unsupported_op_list:
# the highest priority condition: If ops don't have bf16 computing kernels,
@@ -320,108 +276,149 @@ def bf16_guard():
yield


def rewrite_program(main_prog, amp_lists, use_bf16_guard):
def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
bf16 mode.
Traverse all ops in the whole model and set their inputs and outputs
to the bf16 data type. This function applies special handling to
batch normalization, keeping the batchnorm computation in FP32.
Args:
main_prog (Program): The main program for training.
program (Program): The program to be cast to bf16.
amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object.
use_bf16_guard(bool): Determine whether to use `bf16_guard` when
constructing the program. Default True.
"""
block = main_prog.global_block()
ops = block.ops
white_op_set = set()
black_op_set = set()
for op in ops:

# NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the non-iterable DataLoader,
# we don't need to handle reader op and the input of 'create_py_reader' is not
# in block, which may result in errors.
# See GeneratorLoader._init_non_iterable() for details.
if op.type == 'create_py_reader' or op.type == 'read':
continue

if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
continue

if op.type in amp_lists.black_list:
black_op_set.add(op)
elif op.type in amp_lists.white_list:
white_op_set.add(op)
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False

if amp_lists is None:
amp_lists = AutoMixedPrecisionListsBF16()
global_block = program.global_block()
keep_fp32_ops = set()
to_bf16_var_names = set()
to_bf16_pre_cast_ops = set()
origin_ops = []
for block in program.blocks:
origin_ops.extend(block.ops)

for block in program.blocks:
ops = block.ops
for op in ops:
if op.type == 'create_py_reader' or op.type == 'read':
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = None
try:
in_var = block.var(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(in_var_name))

if in_var is None or in_var.type not in _valid_types:
continue

if in_var.dtype == core.VarDesc.VarType.FP32:
if in_var.is_data:
to_bf16_pre_cast_ops.add(op)
else:
prev_op = in_var.op
# if it's one of inputs
if prev_op in black_op_set or \
prev_op.type in amp_lists.black_list:
is_black_op = True
elif prev_op in white_op_set or \
prev_op.type in amp_lists.white_list:
is_white_op = True
if is_black_op:
black_op_set.add(op)
elif is_white_op:
white_op_set.add(op)
else:
pass
else:
# For numerical safety, we apply fp32 computation to ops whose
# list membership could not be determined.
black_op_set.add(op)

idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32)
elif op in white_op_set:
if use_bf16_guard:
if not (op.has_attr('op_namescope') and
(_bf16_guard_pattern in op.attr("op_namescope"))):
idx += 1
continue
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.BF16)

num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16)
else:
pass
in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
to_bf16_var_names.add(in_var_name)

_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".
format(op.type, in_var_name, in_var.dtype))

idx += num_cast_ops + 1
for out_name in op.output_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = None
try:
out_var = block.var(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
out_var = global_block.var(out_var_name)
if out_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(out_var_name))

if out_var is None or out_var.type not in _valid_types:
continue

if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)

_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".
format(op.type, out_var_name, out_var.dtype))
if op.has_attr('in_dtype') and op.attr(
'in_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('in_dtype', core.VarDesc.VarType.BF16)
if op.has_attr('out_dtype') and op.attr(
'out_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
if op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.BF16)
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')

# process ops in keep_fp32_ops
op_var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
for block in program.blocks:
ops = block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op not in keep_fp32_ops:
if op in to_bf16_pre_cast_ops:
in_var_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16)
num_cast_ops += in_var_cast_num
else:
pre_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32)
num_cast_ops += pre_cast_num
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.BF16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
if post_op in keep_fp32_ops:
continue
post_cast_num = _insert_cast_post_op(
block, op, idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16, out_var_name,
op_var_rename_map)
num_cast_ops += post_cast_num
idx += num_cast_ops + 1

_rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
return to_bf16_var_names
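
For orientation, a minimal usage sketch of the API added by this commit follows. The import path is only assumed from the file the commit touches (python/paddle/fluid/contrib/mixed_precision/bf16_utils.py), and the network-building code around the call is illustrative, not taken from the commit.

import paddle.fluid as fluid
# Assumed import path, derived from the modified file; the symbols may also
# be re-exported from a higher-level mixed_precision package.
from paddle.fluid.contrib.mixed_precision.bf16_utils import (
    bf16_guard, cast_model_to_bf16)

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 16], dtype='float32')
    # With use_bf16_guard=True, only ops created under bf16_guard() are
    # candidates for bf16; everything else is kept in fp32.
    with bf16_guard():
        y = fluid.layers.fc(input=x, size=8)

# Cast eligible ops and variables of the whole program to bf16.
# Returns the set of variable names whose dtype was switched to bf16.
to_bf16_var_names = cast_model_to_bf16(
    main_prog, amp_lists=None, use_bf16_guard=True)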
