@@ -72,15 +72,22 @@ def test_backend_requirement_support_checks():
7272 """Test the backend_requirement decorator support checks."""
7373
7474 @supported_compute_capability ([80 , 86 , 89 , 90 ])
75- def _cutlass_check (x , backend ):
76- return x . shape [ 0 ] > 0
75+ def _cudnn_check_my_kernel (x , backend ):
76+ return True
7777
7878 @supported_compute_capability ([75 , 80 , 86 , 89 , 90 ])
79- def _cudnn_check (x , backend ):
80- return x . shape [ 0 ] > 0
79+ def _cutlass_check_my_kernel (x , backend ):
80+ return True
8181
82- @backend_requirement ({"cutlass" : _cutlass_check , "cudnn" : _cudnn_check })
83- def my_kernel (x , backend = "cutlass" ):
82+ def _common_check (x , backend ):
83+ # Common requirement: must be 2D
84+ return x .dim () == 2
85+
86+ @backend_requirement (
87+ {"cudnn" : _cudnn_check_my_kernel , "cutlass" : _cutlass_check_my_kernel },
88+ common_check = _common_check ,
89+ )
90+ def my_kernel (x , backend = "cudnn" ):
8491 return x * 2
8592
8693 # Check methods added
@@ -96,11 +103,14 @@ def my_kernel(x, backend="cutlass"):
96103
97104 # Check compute capability support
98105 assert my_kernel .is_backend_supported ("cutlass" , 80 ) is True
99- assert my_kernel .is_backend_supported ("cutlass" , 75 ) is False
100- assert my_kernel .is_backend_supported ("cudnn" , 75 ) is True
106+ assert my_kernel .is_backend_supported ("cutlass" , 75 ) is True # cutlass supports 75
107+ assert (
108+ my_kernel .is_backend_supported ("cudnn" , 75 ) is False
109+ ) # cudnn does NOT support 75
110+ assert my_kernel .is_backend_supported ("cudnn" , 80 ) is True
101111
102112 # Check cross-backend compute capability
103- assert my_kernel .is_compute_capability_supported (75 ) is True # cudnn has it
113+ assert my_kernel .is_compute_capability_supported (75 ) is True # cutlass has it
104114 assert my_kernel .is_compute_capability_supported (80 ) is True # both have it
105115 assert my_kernel .is_compute_capability_supported (70 ) is False # neither has it
106116
@@ -110,11 +120,16 @@ def test_backend_requirement_wrapped_function():
110120 if not torch .cuda .is_available ():
111121 pytest .skip ("Skipping CUDA tests (no GPU available)" )
112122
113- @supported_compute_capability ([80 , 86 , 89 , 90 ])
123+ # Get actual device capability
124+ x = torch .randn (1 , 1 , device = "cuda" )
125+ major , minor = torch .cuda .get_device_capability (x .device )
126+ actual_capability = major * 10 + minor
127+
128+ @supported_compute_capability ([80 , 86 , 89 , 90 , actual_capability ])
114129 def _cutlass_check (x , backend ):
115130 return x .shape [0 ] > 0
116131
117- @supported_compute_capability ([75 , 80 , 86 , 89 , 90 ])
132+ @supported_compute_capability ([75 , 80 , 86 , 89 , 90 , actual_capability ])
118133 def _cudnn_check (x , backend ):
119134 return x .shape [0 ] > 0
120135
@@ -124,8 +139,11 @@ def my_kernel(x, backend="cutlass"):
124139
125140 x = torch .randn (10 , 10 , device = "cuda" )
126141
127- # Test unsupported backend
128- with pytest .raises (BackendSupportedError , match = "trtllm" ):
142+ # Test unsupported backend raises error
143+ # The error message may include capability info, so use a flexible pattern
144+ with pytest .raises (
145+ BackendSupportedError , match = "does not support backend 'trtllm'"
146+ ):
129147 my_kernel (x , backend = "trtllm" )
130148
131149 # Test supported backend works
@@ -143,34 +161,34 @@ def test_common_check():
143161 actual_capability = major * 10 + minor
144162
145163 @supported_compute_capability ([80 , 86 , 89 , 90 , actual_capability ])
146- def _backend1_check (x , backend ):
164+ def _cudnn_check_my_kernel (x , backend ):
147165 return True
148166
149167 @supported_compute_capability ([75 , 80 , 86 , 89 , 90 , actual_capability ])
150- def _backend2_check (x , backend ):
168+ def _cutlass_check_my_kernel (x , backend ):
151169 return True
152170
153171 def _common_check (x , backend ):
154172 # Common requirement: must be 2D
155173 return x .dim () == 2
156174
157175 @backend_requirement (
158- {"backend1 " : _backend1_check , "backend2 " : _backend2_check },
176+ {"cudnn " : _cudnn_check_my_kernel , "cutlass " : _cutlass_check_my_kernel },
159177 common_check = _common_check ,
160178 )
161- def my_kernel (x , backend = "backend1 " ):
179+ def my_kernel (x , backend = "cudnn " ):
162180 return x * 2
163181
164182 x_2d = torch .randn (10 , 10 , device = "cuda" )
165183 x_3d = torch .randn (10 , 10 , 10 , device = "cuda" )
166184
167185 # 2D should work with skip_check
168- result = my_kernel (x_2d , backend = "backend1 " , skip_check = True )
186+ result = my_kernel (x_2d , backend = "cudnn " , skip_check = True )
169187 assert result .shape == x_2d .shape
170188
171189 # 3D should fail validation
172190 with pytest .raises (ValueError , match = "Problem size is not supported" ):
173- my_kernel (x_3d , backend = "backend1 " )
191+ my_kernel (x_3d , backend = "cudnn " )
174192
175193
176194def test_functools_wraps_preserves_metadata ():
0 commit comments