diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index d5f077d10119c..0b127d2a11f76 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -72,6 +72,7 @@ set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) +set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 120) if (WITH_MKLDNN) set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py index 1a378da0bdbce..f244b875dd4cb 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -31,7 +31,8 @@ from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model import hypothesis -from hypothesis import given, settings, seed, example, assume +from hypothesis import given, settings, seed, reproduce_failure +import hypothesis.strategies as st logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -78,6 +79,11 @@ def __init__(self, *args, **kwargs): abs_dir = os.path.abspath(os.path.dirname(__file__)) self.cache_dir = os.path.join(abs_dir, str(self.__module__) + '_cache_dir') + self.available_passes_in_framework = set() + self.num_ran_programs = 0 + self.num_invalid_programs = 0 + self.num_skipped_tests = 0 + self.num_predictor_kinds = 0 @abc.abstractmethod def sample_program_configs(self): @@ -99,9 +105,8 @@ def add_skip_case( note: str): self.skip_cases.append((teller, reason, note)) - @abc.abstractmethod def is_program_valid(self, program_config: ProgramConfig) -> bool: - raise NotImplementedError + return True def run_test_config(self, model, params, prog_config, pred_config, feed_data) -> Dict[str, np.ndarray]: @@ -110,6 +115,8 @@ def run_test_config(self, model, params, prog_config, pred_config, ''' pred_config.set_model_buffer(model, len(model), params, len(params)) predictor = paddle_infer.create_predictor(pred_config) + self.available_passes_in_framework = self.available_passes_in_framework | set( + pred_config.pass_builder().all_passes()) for name, _ in prog_config.inputs.items(): input_tensor = predictor.get_input_handle(name) @@ -277,39 +284,118 @@ def __init__(self, *args, **kwargs): def check_op_version(self): status = True for pass_name in self.passes: + if pass_name not in self.available_passes_in_framework: + continue if not PassVersionChecker.IsCompatible(pass_name): self.fail_log('{} version check failed.'.format(pass_name)) status = False return status - def assert_op_size(self, fusion_before_num, fusion_after_num, origin_model): + def add_skip_pass_case(self): + return + + def assert_op_list(self, op_list_after_fusion): if not self.passes: raise ValueError( - 'In PassAutoScan you should give a valid pass name.') + "In PassAutoScan you should give a valid pass name.") last_passed_program = os.path.join(self.cache_dir, - self.passes[-1] + '.pdmodel') + self.passes[-1] + ".pdmodel") + if not os.path.exists(last_passed_program): + raise ValueError( + "Cannot find file {}, please make sure that your pass name is correct". + format(last_passed_program)) model_bytes = paddle.static.load_from_file(last_passed_program) pg = paddle.static.deserialize_program(model_bytes) main_block = pg.desc.block(0) - after_op_size = main_block.op_size() - pg = paddle.static.deserialize_program(origin_model) - main_block = pg.desc.block(0) - before_op_size = main_block.op_size() - self.assertTrue(before_op_size == fusion_before_num, - 'before fusion op size is {}, but got {}!'.format( - before_op_size, fusion_before_num)) - self.assertTrue(after_op_size == fusion_after_num, - 'after fusion op size is {}, but got {}!'.format( - after_op_size, fusion_after_num)) + after_op_list = list() + for i in range(main_block.op_size()): + if main_block.op(i).type() in ["feed", "fetch"]: + continue + after_op_list.append(main_block.op(i).type()) + self.assertTrue( + op_list_after_fusion == after_op_list, + "Expected operator list after fusion is {}, but now it's {}".format( + op_list_after_fusion, after_op_list), ) - def run_test(self, quant=False, *args, **kwargs): + def run_and_statis( + self, + quant=False, + max_examples=100, + reproduce=None, + min_success_num=25, + max_duration=180, + passes=None, ): + if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev": + max_examples *= 10 + min_success_num *= 10 + # while at ce phase, there's no limit on time + max_duration = -1 + start_time = time.time() + settings.register_profile( + "ci", + max_examples=max_examples, + suppress_health_check=hypothesis.HealthCheck.all(), + deadline=None, + print_blob=True, + derandomize=True, + report_multiple_bugs=False, ) + settings.load_profile("ci") + assert passes is not None, "Parameter of passes must be defined in function run_and_statis." + self.passes = passes + + self.add_skip_pass_case() + + def program_generator(draw): + return self.sample_program_config(draw) + + def run_test(prog_config): + return self.run_test(quant=quant, prog_configs=[prog_config]) + + generator = st.composite(program_generator) + loop_func = given(generator())(run_test) + if reproduce is not None: + loop_func = reproduce(loop_func) + logging.info("Start to running test of {}".format(type(self))) + loop_func() + logging.info( + "===================Statistical Information===================") + logging.info("Number of Generated Programs: {}".format( + self.num_ran_programs + self.num_invalid_programs)) + logging.info("Number of Invalid Programs: {}".format( + self.num_invalid_programs)) + logging.info("Number of Ran Programs: {}".format(self.num_ran_programs)) + logging.info("Number of Skipped Tests: {}".format( + self.num_skipped_tests)) + successful_ran_programs = int(self.num_ran_programs - + self.num_skipped_tests / + self.num_predictor_kinds) + logging.info( + "Number of successfully ran programs approximately equal to {}". + format(successful_ran_programs)) + if successful_ran_programs < min_success_num: + logging.warning( + "satisfied_programs = ran_programs - num_skipped_tests / num_predictor_kinds" + ) + logging.error( + "At least {} programs need to ran successfully, but now only about {} programs satisfied.". + format(min_success_num, successful_ran_programs)) + assert False + used_time = time.time() - start_time + if max_duration > 0 and used_time > max_duration: + logging.error( + "The duration exceeds {} seconds, if this is neccessary, try to set a larger number for parameter `max_duration`.". + format(max_duration)) + assert False + + def run_test(self, quant=False, prog_configs=None): status = True - for prog_config in self.sample_program_configs(*args, **kwargs): + for prog_config in prog_configs: # if program is invalid, we should skip that cases. if not self.is_program_valid(prog_config): + self.num_invalid_programs += 1 continue - + self.num_ran_programs += 1 model, params = create_fake_model(prog_config) if quant: model, params = create_quant_model(model, params) @@ -330,13 +416,16 @@ def run_test(self, quant=False, *args, **kwargs): feed_data)) self.success_log('RUN_CPU_BASELINE done') - for pred_config, nodes_num, ( + self.num_predictor_kinds = 0 + for pred_config, op_list, ( atol, rtol) in self.sample_predictor_configs(prog_config): + self.num_predictor_kinds += 1 # skip info skip_flag = False for skip_info in self.skip_cases: if skip_info[0](prog_config, pred_config): skip_flag = True + self.num_skipped_tests += 1 if skip_info[1] == SkipReasons.PASS_ACCURACY_ERROR: self.skip_log("[PASS_ACCURACY_ERROR] " + skip_info[ 2] + ' ' + ' vs ' + self.inference_config_str( @@ -357,7 +446,7 @@ def run_test(self, quant=False, *args, **kwargs): self.assert_tensors_near(atol, rtol, results[-1], results[0]) if not skip_flag: - self.assert_op_size(nodes_num[0], nodes_num[1], model) + self.assert_op_list(op_list) except Exception as e: self.fail_log( diff --git a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py index 3479d4785d524..a58be906762cf 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py @@ -34,17 +34,24 @@ class TensorConfig: def __init__(self, lod: Optional[List[List[int]]]=None, - data_gen: Optional[Callable[..., np.array]]=None): + data_gen: Optional[Callable[..., np.array]]=None, + shape: Optional[List[List[int]]]=None): ''' shape: The shape of the tensor. dtype: The data type of the tensor. data: The value of WeightVar. for input, it should be None ''' self.lod = lod - self.data_gen = data_gen - self.data = data_gen() - self.dtype = data_gen().dtype - self.shape = data_gen().shape + if data_gen is not None: + self.data_gen = data_gen + self.data = data_gen() + self.dtype = data_gen().dtype + self.shape = data_gen().shape + else: + assert shape is not None, "While data_gen is not defined, shape must not be None" + self.data = np.random.normal(0.0, 1.0, shape).astype(np.float32) + self.shape = shape + self.dtype = self.data.dtype def __repr__(self): return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype}) @@ -57,11 +64,15 @@ def __init__(self, type: str, inputs: Dict[str, List[str]], outputs: Dict[str, List[str]], - attrs: Dict[str, Any]): + attrs: Dict[str, Any]=None, + **kwargs): self.type = type self.inputs = inputs self.outputs = outputs self.attrs = attrs + if self.attrs is None: + self.attrs = dict() + self.attrs.update(kwargs) def __repr__(self): log_str = self.type diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py index 3529b4084d56b..7bbf8502955ff 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py @@ -71,7 +71,19 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool: return True - def sample_program_configs(self, *args, **kwargs): + def sample_program_config(self, draw): + is_sparse = draw(st.booleans()) + is_distributed = draw(st.booleans()) + padding_idx = draw(st.integers()) + axis = draw(st.integers(min_value=-4, max_value=4)) + op_type = draw(st.sampled_from(['lookup_table', 'lookup_table_v2'])) + epsilon = draw(st.floats(min_value=0, max_value=0.001)) + # begin_norm_axis has to be 2 + begin_norm_axis = 2 + batch_size = draw(st.integers(min_value=1, max_value=4)) + input_dim = draw(st.sampled_from([32, 64])) + weight_size = draw(st.sampled_from([[64, 64], [64, 32]])) + def generate_input(attrs): if attrs[0]['op_type'] == 'lookup_table': return np.random.randint( @@ -101,19 +113,19 @@ def generate_weight2(attrs): np.float32) attrs = [{ - 'is_sparse': kwargs['is_sparse'], - 'is_distributed': kwargs['is_distributed'], - 'padding_idx': kwargs['padding_idx'], - 'op_type': kwargs['op_type'] + 'is_sparse': is_sparse, + 'is_distributed': is_distributed, + 'padding_idx': padding_idx, + 'op_type': op_type }, { - 'axis': kwargs['axis'] + 'axis': axis }, { - 'begin_norm_axis': kwargs['begin_norm_axis'], - 'epsilon': kwargs['epsilon'] + 'begin_norm_axis': begin_norm_axis, + 'epsilon': epsilon }, { - 'batch_size': kwargs['batch_size'], - 'input_dim': kwargs['input_dim'], - 'weight_size': kwargs['weight_size'] + 'batch_size': batch_size, + 'input_dim': input_dim, + 'weight_size': weight_size }] emb_op1 = OpConfig( @@ -203,13 +215,12 @@ def generate_weight2(attrs): }, outputs=["layer_norm_output1"]) - yield program_config + return program_config def sample_predictor_configs(self, program_config): # only used in gpu passes and trt passes. - config = self.create_inference_config( - passes=['embedding_eltwise_layernorm_fuse_pass'], use_gpu=True) - yield config, (10, 5), (1e-5, 1e-5) + config = self.create_inference_config(use_gpu=True) + yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5) # trt static_shape config = self.create_trt_inference_config() config.enable_tensorrt_engine( @@ -219,7 +230,7 @@ def sample_predictor_configs(self, program_config): precision_mode=paddle_infer.PrecisionType.Float32, use_static=False, use_calib_mode=False) - yield config, (10, 5), (1e-5, 1e-5) + yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5) # trt dynamic_shape config = self.create_trt_inference_config() config.enable_tensorrt_engine( @@ -257,7 +268,7 @@ def sample_predictor_configs(self, program_config): "input_data2": [2, 128], "input_data3": [2, 128] }) - yield config, (10, 5), (1e-5, 1e-5) + yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5) def add_skip_pass_case(self): def teller1(program_config, predictor_config): @@ -272,26 +283,13 @@ def teller1(program_config, predictor_config): self.add_skip_case(teller1, SkipReasons.PASS_ACCURACY_ERROR, "The pass output has diff in a specific case.") - @given( - is_sparse=st.booleans(), - is_distributed=st.booleans(), - padding_idx=st.integers(), - axis=st.integers( - min_value=-4, max_value=4), - op_type=st.sampled_from(['lookup_table', 'lookup_table_v2']), - epsilon=st.floats( - min_value=0, max_value=0.001), - begin_norm_axis=st.integers( - min_value=-4, max_value=4), - batch_size=st.integers( - min_value=1, max_value=4), - input_dim=st.sampled_from([32, 64]), - weight_size=st.sampled_from([[64, 64], [64, 32]])) - def test(self, *args, **kwargs): - assume(kwargs['begin_norm_axis'] == 2) - - self.add_skip_pass_case() - self.run_test(quant=False, *args, **kwargs) + def test(self): + # this fuse need to fix, now there's no program can ran successfully + self.run_and_statis( + quant=False, + max_examples=50, + passes=["embedding_eltwise_layernorm_fuse_pass"], + min_success_num=0) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py index a62adcea3f943..34f611e0bb67b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,42 +12,159 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import AnalysisConfig -from paddle.fluid.core import PassVersionChecker - - -class FcFusePassTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 128, 768], dtype="float32") - data_y = fluid.data(name="y", shape=[-1, 128, 768], dtype="float32") - fc_out1 = fluid.layers.fc(input=data, - size=3072, - num_flatten_dims=2, - act="relu") - fc_out2 = fluid.layers.fc(input=fc_out1, - size=768, - num_flatten_dims=2) - - self.feeds = {"data": np.random.random((4, 128, 768)).astype("float32")} - self.fetch_list = [fc_out2] - - def test_check_output(self): - use_gpu = [False] - if core.is_compiled_with_cuda(): - use_gpu.append(True) - for i in range(len(use_gpu)): - self.check_output_with_option(use_gpu[i]) - - self.assertTrue(PassVersionChecker.IsCompatible('fc_fuse_pass')) +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestFcFusePass(PassAutoScanTest): + """ + x_var y_var(persistable) + \ / + mul bias_var(persistable) + | + mul_out_var bias_var(persistable) + \ / + elementwise_add + """ + + def sample_predictor_configs(self, program_config): + # cpu + before_num_ops = len(program_config.ops) + 2 + config = self.create_inference_config(use_gpu=False) + yield config, ["fc"], (1e-5, 1e-5) + + # for gpu + config = self.create_inference_config(use_gpu=True) + yield config, ["fc"], (1e-5, 1e-5) + + def add_skip_pass_case(self): + # Here we put some skip rules to avoid known bugs + def teller1(program_config, predictor_config): + # shape of bias should be [1, mul_y_shape[-1]] or [mul_y_shape[-1]] + x_shape = list(program_config.inputs["mul_x"].shape) + y_shape = list(program_config.weights["mul_y"].shape) + bias_shape = program_config.weights["bias"].shape + if (bias_shape != [y_shape[-1], ] and + bias_shape != [1, y_shape[-1]]): + return True + return False + + def teller2(program_config, predictor_config): + # TODO fuse has bug while axis != -1 + if program_config.ops[1].attrs["axis"] != -1: + return True + return False + + self.add_skip_case( + teller1, + SkipReasons.PASS_ACCURACY_ERROR, + "The pass output has diff while shape of bias is not [out_size] or [1, out_size].", + ) + self.add_skip_case( + teller2, + SkipReasons.PASS_ACCURACY_ERROR, + "The pass output has diff while axis of elementwise_add is not -1.", + ) + + def is_program_valid(self, prog_config): + add_x_rank = prog_config.ops[0].attrs["x_num_col_dims"] + 1 + add_y_rank = len(prog_config.weights["bias"].shape) + axis = prog_config.ops[1].attrs["axis"] + if add_x_rank == add_y_rank: + if axis != -1 or axis != 0: + return False + return True + + def sample_program_config(self, draw): + # 1. Generate shape of input:X of mul + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=4), min_size=2, max_size=4)) + # 2. Generate attr:x_num_col_dims/y_num_col_dims of mul + x_num_col_dims = draw( + st.integers( + min_value=1, max_value=len(x_shape) - 1)) + y_num_col_dims = 1 + # 3. Generate legal shape of input:Y of mul + y_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=8), min_size=2, max_size=2)) + y_shape[0] = int(np.prod(x_shape[x_num_col_dims:])) + # 4. Generate legal attr:axis of elementwise_add + mul_out_shape = x_shape[:x_num_col_dims] + y_shape[1:] + axis = draw(st.integers(min_value=-1, max_value=x_num_col_dims)) + # 5. Generate legal shape of input:Y of elementwise_add + if axis >= 0: + max_bias_rank = x_num_col_dims + 1 - axis + bias_rank = draw(st.integers(min_value=1, max_value=max_bias_rank)) + bias_shape = mul_out_shape[axis:axis + bias_rank] + else: + max_bias_rank = 1 + bias_rank = draw( + st.integers( + min_value=1, max_value=len(mul_out_shape))) + bias_shape = mul_out_shape[-1 * bias_rank:] + # 6. Random choose if use broadcast for elementwise_add, e.g [3, 4] -> [1, 4] + if draw(st.booleans()): + broadcast_dims = draw(st.integers(min_value=1, max_value=bias_rank)) + for i in range(0, broadcast_dims): + bias_shape[i] = 1 + # 7. Random choose if add a relu operator + has_relu = draw(st.booleans()) + + # Now we have all the decided parameters to compose a program + # shape of inputs/weights tensors: x_shape, y_shape, bias_shape... + # parameters of operators: x_num_col_dims, y_num_col_dims, axis... + # a random boolean value(has_relu) to decide if program include a relu op + + # Here we will compose a program + # Still has some risks that the program is invalid or cause bug while running + # Use function `is_program_valid` to filter the invalid programs before running + # Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing + mul_op = OpConfig( + "mul", + inputs={"X": ["mul_x"], + "Y": ["mul_y"]}, + outputs={"Out": ["mul_out"]}, + x_num_col_dims=x_num_col_dims, + y_num_col_dims=y_num_col_dims, ) + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["mul_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis, ) + ops = [mul_op, add_op] + if has_relu: + relu_op = OpConfig( + "relu", + inputs={"X": ["add_out"]}, + outputs={"Out": ["relu_out"]}) + ops.append(relu_op) + program_config = ProgramConfig( + ops=ops, + weights={ + "mul_y": TensorConfig(shape=y_shape), + "bias": TensorConfig(shape=bias_shape), + }, + inputs={"mul_x": TensorConfig(shape=x_shape), }, + outputs=ops[-1].outputs["Out"], ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, max_examples=300, passes=["fc_fuse_pass"]) if __name__ == "__main__":