diff --git a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
index c9671bae176071..fbeea09c441fbc 100644
--- a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
+++ b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py
@@ -31,17 +31,37 @@
 
 def skip_unit_test():
     return (
-        not (paddle.is_compiled_with_cuda() or is_custom_device())
+        not (paddle.base.libpaddle.is_compiled_with_cudnn_frontend())
+        or not (paddle.is_compiled_with_cuda() or is_custom_device())
        or paddle.device.cuda.get_device_capability()[0] < 8
     )
 
 
-skip_msg = "only support with cuda and Ampere or later devices"
+skip_msg = "only supported with CUDA on Ampere or later devices; also make sure PaddlePaddle was compiled with -DWITH_CUDNN_FRONTEND=ON"
 
 
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOp(OpTest):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fuse_add = False
+        self.fuse_shortcut = False
+        self.fuse_dual = False
+        self.exhaustive_search = False
+
+    def set_attrs(
+        self,
+        fuse_add=False,
+        fuse_shortcut=False,
+        fuse_dual=False,
+        exhaustive_search=False,
+    ):
+        self.fuse_add = fuse_add
+        self.fuse_shortcut = fuse_shortcut
+        self.fuse_dual = fuse_dual
+        self.exhaustive_search = exhaustive_search
+
     def setUp(self):
         self.__class__.op_type = "fused_dconv_drelu_dbn"
         self.dtype = np.float16
@@ -431,53 +451,44 @@ def init_attr(self):
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOpShortcut(TestFusedDconvDreluDbnOp):
-    def init_attr(self):
-        self.fuse_add = False
-        self.fuse_shortcut = True
-        self.fuse_dual = False
-        self.exhaustive_search = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.set_attrs(fuse_shortcut=True)
 
 
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOpDual(TestFusedDconvDreluDbnOp):
-    def init_attr(self):
-        self.fuse_add = False
-        self.fuse_shortcut = False
-        self.fuse_dual = True
-        self.exhaustive_search = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.set_attrs(fuse_dual=True)
 
 
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOpShortcutAdd(TestFusedDconvDreluDbnOp):
-    def init_attr(self):
-        self.fuse_add = True
-        self.fuse_shortcut = True
-        self.fuse_dual = False
-        self.exhaustive_search = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.set_attrs(fuse_add=True, fuse_shortcut=True)
 
 
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOpDualAdd(TestFusedDconvDreluDbnOp):
-    def init_attr(self):
-        self.fuse_add = True
-        self.fuse_shortcut = False
-        self.fuse_dual = True
-        self.exhaustive_search = False
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.set_attrs(fuse_add=True, fuse_dual=True)
 
 
 @skip_check_grad_ci(reason="no grad op")
 @unittest.skipIf(skip_unit_test(), skip_msg)
 class TestFusedDconvDreluDbnOpExhaustive(TestFusedDconvDreluDbnOp):
-    def init_attr(self):
-        self.fuse_add = False
-        self.fuse_shortcut = False
-        self.fuse_dual = False
-        self.exhaustive_search = True
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.set_attrs(exhaustive_search=True)
 
 
 if __name__ == '__main__':
-    np.random.seed(0)
-    unittest.main()
+    for _ in range(10):
+        np.random.seed(np.random.randint(0, 1000))
+        unittest.main(exit=False)