Commit

Changes for CI
arlesniak committed Mar 1, 2021
1 parent 87eb108 commit ff9f83a
Showing 4 changed files with 1 addition and 619 deletions.
3 changes: 0 additions & 3 deletions python/paddle/fluid/contrib/mixed_precision/__init__.py
@@ -16,8 +16,6 @@

 from . import decorator
 from .decorator import *
-from . import decorator_bf16
-from .decorator_bf16 import *
 from . import fp16_lists
 from .fp16_lists import *
 from . import bf16_lists
@@ -28,7 +26,6 @@
 from .bf16_utils import *

 __all__ = decorator.__all__
-__all__ += decorator_bf16.__all__
 __all__ += fp16_lists.__all__
 __all__ += bf16_lists.__all__
 __all__ += fp16_utils.__all__
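After this change, the package __init__ no longer imports decorator_bf16, so the names re-exported by paddle.fluid.contrib.mixed_precision come only from the submodules still listed above. A quick, hypothetical way to confirm what the package exposes (assuming a Paddle build that includes this commit):

import paddle.fluid.contrib.mixed_precision as mp
# __all__ is built by concatenating the remaining submodules' __all__ lists,
# so no decorator_bf16 names should appear here anymore.
print(mp.__all__)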
180 changes: 1 addition & 179 deletions python/paddle/fluid/contrib/mixed_precision/bf16_utils.py
@@ -26,10 +26,7 @@
 import logging
 import numpy as np

-__all__ = [
-    "bf16_guard", "cast_model_to_bf16", "cast_parameters_to_bf16",
-    "convert_float_to_uint16", "convert_uint16_to_float"
-]
+__all__ = ["bf16_guard", "convert_float_to_uint16", "convert_uint16_to_float"]

 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
@@ -323,181 +320,6 @@ def bf16_guard():
         yield


-def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
-    """
-    Traverse all ops in the whole model and set their inputs and outputs
-    to the bf16 data type. This function will do some special process for
-    the batch normalization, which keeps the computational process of
-    batchnorms in FP32.
-    Args:
-        program (Program): The used program.
-        amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object.
-        use_bf16_guard(bool): Determine whether to use `bf16_guard` when
-                              constructing the program. Default True.
-    """
-
-    if amp_lists is None:
-        amp_lists = AutoMixedPrecisionListsBF16()
-    global_block = program.global_block()
-    keep_fp32_ops = set()
-    to_bf16_var_names = set()
-    to_bf16_pre_cast_ops = set()
-    origin_ops = []
-    for block in program.blocks:
-        origin_ops.extend(block.ops)
-
-    for block in program.blocks:
-        ops = block.ops
-        for op in ops:
-            if op.type == 'create_py_reader' or op.type == 'read':
-                continue
-            if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard):
-                keep_fp32_ops.add(op)
-                continue  # processed below
-            for in_name in op.input_names:
-                if op.type in {
-                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-                } and in_name not in {'X', 'Z'}:
-                    continue
-                for in_var_name in op.input(in_name):
-                    in_var = None
-                    try:
-                        in_var = block.var(in_var_name)
-                    except ValueError as e:
-                        _logger.debug(
-                            "-- {}, try to get it in the global block --".
-                            format(e))
-                        in_var = global_block.var(in_var_name)
-                        if in_var is not None:
-                            _logger.debug(
-                                "-- var {} is got in the global block --".
-                                format(in_var_name))
-
-                    if in_var is None or in_var.type not in _valid_types:
-                        continue
-
-                    if in_var.dtype == core.VarDesc.VarType.FP32:
-                        if in_var.is_data:
-                            to_bf16_pre_cast_ops.add(op)
-                        else:
-                            in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
-                            to_bf16_var_names.add(in_var_name)
-
-                    _logger.debug(
-                        "-- op type: {}, in var name: {}, in var dtype: {} --".
-                        format(op.type, in_var_name, in_var.dtype))
-
-            for out_name in op.output_names:
-                if op.type in {
-                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-                } and out_name != 'Y':
-                    continue
-                for out_var_name in op.output(out_name):
-                    out_var = None
-                    try:
-                        out_var = block.var(out_var_name)
-                    except ValueError as e:
-                        _logger.debug(
-                            "-- {}, try to get it in the global block --".
-                            format(e))
-                        out_var = global_block.var(out_var_name)
-                        if out_var is not None:
-                            _logger.debug(
-                                "-- var {} is got in the global block --".
-                                format(out_var_name))
-
-                    if out_var is None or out_var.type not in _valid_types:
-                        continue
-
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
-                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
-
-                    _logger.debug(
-                        "-- op type: {}, out var name: {}, out var dtype: {} --".
-                        format(op.type, out_var_name, out_var.dtype))
-            if op.has_attr('in_dtype') and op.attr(
-                    'in_dtype') == core.VarDesc.VarType.FP32:
-                op._set_attr('in_dtype', core.VarDesc.VarType.BF16)
-            if op.has_attr('out_dtype') and op.attr(
-                    'out_dtype') == core.VarDesc.VarType.FP32:
-                op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
-            if op.has_attr('dtype') and op.attr(
-                    'dtype') == core.VarDesc.VarType.FP32:
-                op._set_attr('dtype', core.VarDesc.VarType.BF16)
-            if op.has_attr('use_mkldnn'):
-                op._set_attr('use_mkldnn', True)
-                op._set_attr('mkldnn_data_type', 'bfloat16')
-
-    # process ops in keep_fp32_ops
-    op_var_rename_map = [
-        collections.OrderedDict() for _ in range(len(program.blocks))
-    ]
-    for block in program.blocks:
-        ops = block.ops
-        idx = 0
-        while idx < len(ops):
-            op = ops[idx]
-            num_cast_ops = 0
-            if op not in keep_fp32_ops:
-                if op in to_bf16_pre_cast_ops:
-                    in_var_cast_num = _insert_cast_op(block, op, idx,
-                                                      core.VarDesc.VarType.FP32,
-                                                      core.VarDesc.VarType.BF16)
-                    num_cast_ops += in_var_cast_num
-            else:
-                pre_cast_num = _insert_cast_op(block, op, idx,
-                                               core.VarDesc.VarType.BF16,
-                                               core.VarDesc.VarType.FP32)
-                num_cast_ops += pre_cast_num
-                for out_var_name in op.output_arg_names:
-                    out_var = block.vars.get(out_var_name)
-                    if out_var is None or out_var.type not in _valid_types:
-                        continue
-                    if out_var.dtype == core.VarDesc.VarType.BF16:
-                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
-                        post_ops = find_true_post_op(ops, op, out_var_name)
-                        for post_op in post_ops:
-                            if post_op in keep_fp32_ops:
-                                continue
-                            post_cast_num = _insert_cast_post_op(
-                                block, op, idx + pre_cast_num + 1,
-                                core.VarDesc.VarType.FP32,
-                                core.VarDesc.VarType.BF16, out_var_name,
-                                op_var_rename_map)
-                            num_cast_ops += post_cast_num
-            idx += num_cast_ops + 1
-
-    _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
-    return to_bf16_var_names
-
-
-def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None):
-    """
-    Traverse all parameters in the whole model and set them to the BF16 data type.
-    Whereas, this function will keep parameters of batchnorms in FP32.
-    Args:
-        place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors.
-        program (Program): The used program.
-        scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
-                                      Default is None.
-        to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names`
-                                               will be set to BF16. Usually, it is the returned
-                                               value of `cast_model_to_bf16` API.
-    """
-    all_parameters = []
-    for block in program.blocks:
-        all_parameters.extend(block.all_parameters())
-
-    bf16_var_names = to_bf16_var_names if to_bf16_var_names else set()
-    var_scope = scope if scope else global_scope()
-    for param in all_parameters:
-        if param.name in bf16_var_names:
-            _logger.debug("---- cast {} to bf16 dtype ----".format(param.name))
-            param_t = var_scope.find_var(param.name).get_tensor()
-            data = np.array(param_t)
-            param_t.set(np.uint16(data), place)
-
-
 def rewrite_program(main_prog, amp_lists, use_bf16_guard):
     """
     Traverse all ops in current block and insert cast op according to
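For reference, the docstrings above describe how the two deleted helpers were meant to be used together before this commit: cast_model_to_bf16 rewrites the program's variable dtypes to bf16 and returns the set of var names it switched, and cast_parameters_to_bf16 then casts the matching parameter tensors in the scope. A minimal, hypothetical sketch of that pre-commit usage, reconstructed only from the docstrings shown in this diff (the model definition and executor setup are assumed, not part of the change):

import paddle.fluid as fluid
# These imports only resolve on revisions before this commit removes the helpers.
from paddle.fluid.contrib.mixed_precision.bf16_utils import (
    cast_model_to_bf16, cast_parameters_to_bf16)

main_prog = fluid.Program()      # assumed: an FP32 program built elsewhere
startup_prog = fluid.Program()

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)            # initialize the FP32 parameters first

# Step 1: switch var dtypes to bf16 (batch norm and friends stay FP32).
to_bf16_var_names = cast_model_to_bf16(main_prog, use_bf16_guard=True)
# Step 2: cast the already-initialized parameter tensors that were switched.
cast_parameters_to_bf16(place, main_prog, to_bf16_var_names=to_bf16_var_names)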