Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions test/legacy_test/test_fused_dconv_drelu_dbn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,37 @@

def skip_unit_test():
return (
not (paddle.is_compiled_with_cuda() or is_custom_device())
not (paddle.base.libpaddle.is_compiled_with_cudnn_frontend())
or not (paddle.is_compiled_with_cuda() or is_custom_device())
or paddle.device.cuda.get_device_capability()[0] < 8
)


skip_msg = "only support with cuda and Ampere or later devices"
skip_msg = "only support with cuda and Ampere or later devices, also please ensure you have used compile mode to install paddlepaddle with -WITH_CUDNN_FRONTEND ON"


@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOp(OpTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fuse_add = False
self.fuse_shortcut = False
self.fuse_dual = False
self.exhaustive_search = False

def set_attrs(
self,
fuse_add=False,
fuse_shortcut=False,
fuse_dual=False,
exhaustive_search=False,
):
self.fuse_add = fuse_add
self.fuse_shortcut = fuse_shortcut
self.fuse_dual = fuse_dual
self.exhaustive_search = exhaustive_search

def setUp(self):
self.__class__.op_type = "fused_dconv_drelu_dbn"
self.dtype = np.float16
Expand Down Expand Up @@ -431,53 +451,44 @@ def init_attr(self):
@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOpShortcut(TestFusedDconvDreluDbnOp):
def init_attr(self):
self.fuse_add = False
self.fuse_shortcut = True
self.fuse_dual = False
self.exhaustive_search = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_attrs(fuse_shortcut=True)


@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOpDual(TestFusedDconvDreluDbnOp):
def init_attr(self):
self.fuse_add = False
self.fuse_shortcut = False
self.fuse_dual = True
self.exhaustive_search = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_attrs(fuse_dual=True)


@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOpShortcutAdd(TestFusedDconvDreluDbnOp):
def init_attr(self):
self.fuse_add = True
self.fuse_shortcut = True
self.fuse_dual = False
self.exhaustive_search = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_attrs(fuse_add=True, fuse_shortcut=True)


@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOpDualAdd(TestFusedDconvDreluDbnOp):
def init_attr(self):
self.fuse_add = True
self.fuse_shortcut = False
self.fuse_dual = True
self.exhaustive_search = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_attrs(fuse_add=True, fuse_dual=True)


@skip_check_grad_ci(reason="no grad op")
@unittest.skipIf(skip_unit_test(), skip_msg)
class TestFusedDconvDreluDbnOpExhaustive(TestFusedDconvDreluDbnOp):
def init_attr(self):
self.fuse_add = False
self.fuse_shortcut = False
self.fuse_dual = False
self.exhaustive_search = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_attrs(exhaustive_search=True)


if __name__ == '__main__':
np.random.seed(0)
unittest.main()
for _ in range(10):
np.random.seed(np.random.randint(0, 1000))
unittest.main()