Commit 3c9f687

Fixed some tests
1 parent 19ef156 commit 3c9f687

File tree

1 file changed: +37 -19 lines changed


tests/utils/test_decorators.py

Lines changed: 37 additions & 19 deletions
@@ -72,15 +72,22 @@ def test_backend_requirement_support_checks():
     """Test the backend_requirement decorator support checks."""
 
     @supported_compute_capability([80, 86, 89, 90])
-    def _cutlass_check(x, backend):
-        return x.shape[0] > 0
+    def _cudnn_check_my_kernel(x, backend):
+        return True
 
     @supported_compute_capability([75, 80, 86, 89, 90])
-    def _cudnn_check(x, backend):
-        return x.shape[0] > 0
+    def _cutlass_check_my_kernel(x, backend):
+        return True
 
-    @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check})
-    def my_kernel(x, backend="cutlass"):
+    def _common_check(x, backend):
+        # Common requirement: must be 2D
+        return x.dim() == 2
+
+    @backend_requirement(
+        {"cudnn": _cudnn_check_my_kernel, "cutlass": _cutlass_check_my_kernel},
+        common_check=_common_check,
+    )
+    def my_kernel(x, backend="cudnn"):
         return x * 2
 
     # Check methods added
@@ -96,11 +103,14 @@ def my_kernel(x, backend="cutlass"):
 
     # Check compute capability support
     assert my_kernel.is_backend_supported("cutlass", 80) is True
-    assert my_kernel.is_backend_supported("cutlass", 75) is False
-    assert my_kernel.is_backend_supported("cudnn", 75) is True
+    assert my_kernel.is_backend_supported("cutlass", 75) is True  # cutlass supports 75
+    assert (
+        my_kernel.is_backend_supported("cudnn", 75) is False
+    )  # cudnn does NOT support 75
+    assert my_kernel.is_backend_supported("cudnn", 80) is True
 
     # Check cross-backend compute capability
-    assert my_kernel.is_compute_capability_supported(75) is True  # cudnn has it
+    assert my_kernel.is_compute_capability_supported(75) is True  # cutlass has it
     assert my_kernel.is_compute_capability_supported(80) is True  # both have it
     assert my_kernel.is_compute_capability_supported(70) is False  # neither has it
 
@@ -110,11 +120,16 @@ def test_backend_requirement_wrapped_function():
     if not torch.cuda.is_available():
         pytest.skip("Skipping CUDA tests (no GPU available)")
 
-    @supported_compute_capability([80, 86, 89, 90])
+    # Get actual device capability
+    x = torch.randn(1, 1, device="cuda")
+    major, minor = torch.cuda.get_device_capability(x.device)
+    actual_capability = major * 10 + minor
+
+    @supported_compute_capability([80, 86, 89, 90, actual_capability])
     def _cutlass_check(x, backend):
         return x.shape[0] > 0
 
-    @supported_compute_capability([75, 80, 86, 89, 90])
+    @supported_compute_capability([75, 80, 86, 89, 90, actual_capability])
     def _cudnn_check(x, backend):
         return x.shape[0] > 0
 
@@ -124,8 +139,11 @@ def my_kernel(x, backend="cutlass"):
 
     x = torch.randn(10, 10, device="cuda")
 
-    # Test unsupported backend
-    with pytest.raises(BackendSupportedError, match="trtllm"):
+    # Test unsupported backend raises error
+    # The error message may include capability info, so use a flexible pattern
+    with pytest.raises(
+        BackendSupportedError, match="does not support backend 'trtllm'"
+    ):
         my_kernel(x, backend="trtllm")
 
     # Test supported backend works
@@ -143,34 +161,34 @@ def test_common_check():
     actual_capability = major * 10 + minor
 
     @supported_compute_capability([80, 86, 89, 90, actual_capability])
-    def _backend1_check(x, backend):
+    def _cudnn_check_my_kernel(x, backend):
         return True
 
     @supported_compute_capability([75, 80, 86, 89, 90, actual_capability])
-    def _backend2_check(x, backend):
+    def _cutlass_check_my_kernel(x, backend):
         return True
 
     def _common_check(x, backend):
         # Common requirement: must be 2D
         return x.dim() == 2
 
     @backend_requirement(
-        {"backend1": _backend1_check, "backend2": _backend2_check},
+        {"cudnn": _cudnn_check_my_kernel, "cutlass": _cutlass_check_my_kernel},
         common_check=_common_check,
     )
-    def my_kernel(x, backend="backend1"):
+    def my_kernel(x, backend="cudnn"):
         return x * 2
 
     x_2d = torch.randn(10, 10, device="cuda")
     x_3d = torch.randn(10, 10, 10, device="cuda")
 
     # 2D should work with skip_check
-    result = my_kernel(x_2d, backend="backend1", skip_check=True)
+    result = my_kernel(x_2d, backend="cudnn", skip_check=True)
     assert result.shape == x_2d.shape
 
     # 3D should fail validation
     with pytest.raises(ValueError, match="Problem size is not supported"):
-        my_kernel(x_3d, backend="backend1")
+        my_kernel(x_3d, backend="cudnn")
 
 
 def test_functools_wraps_preserves_metadata():
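Pieced together from the hunks above, the pattern these tests exercise is: per-backend check functions are gated by supported_compute_capability, registered with backend_requirement alongside an optional common_check, and the decorated kernel gains is_backend_supported / is_compute_capability_supported helpers plus a skip_check keyword. The sketch below is assembled from this diff only; the import path is a placeholder and is not part of the commit.

# Minimal usage sketch; the module path below is an assumption, not from this commit.
import torch

from mypkg.utils.decorators import (  # hypothetical import path
    backend_requirement,
    supported_compute_capability,
)


@supported_compute_capability([75, 80, 86, 89, 90])
def _cutlass_check(x, backend):
    # Per-backend requirement evaluated at call time.
    return x.shape[0] > 0


@supported_compute_capability([80, 86, 89, 90])
def _cudnn_check(x, backend):
    return x.shape[0] > 0


def _common_check(x, backend):
    # Requirement shared by every backend: input must be 2D.
    return x.dim() == 2


@backend_requirement(
    {"cudnn": _cudnn_check, "cutlass": _cutlass_check},
    common_check=_common_check,
)
def my_kernel(x, backend="cudnn"):
    return x * 2


# Introspection helpers attached by the decorator, as the tests assert:
assert my_kernel.is_backend_supported("cutlass", 75) is True
assert my_kernel.is_backend_supported("cudnn", 75) is False
assert my_kernel.is_compute_capability_supported(70) is False

# Per-call validation can be bypassed with skip_check=True (the tests do this on a
# CUDA tensor; a CPU tensor is used here only to keep the sketch self-contained).
out = my_kernel(torch.randn(10, 10), backend="cudnn", skip_check=True)
assert out.shape == (10, 10)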
