Skip to content

Commit cf74c40

Browse files
For drivers with version in [12000, 12040) skip tests requiring KernelParamInfo
Except for one test where we check that NotImplementedError is raised.
1 parent cf9424e commit cf74c40

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

cuda_core/tests/test_module.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232

3333

3434
@pytest.fixture(scope="module")
35-
def cuda12_prerequisite_check():
35+
def cuda12_4_prerequisite_check():
3636
# binding availability depends on cuda-python version
3737
# and version of underlying CUDA toolkit
3838
_py_major_ver, _ = get_binding_version()
3939
_driver_ver = handle_return(driver.cuDriverGetVersion())
40-
return _py_major_ver >= 12 and _driver_ver >= 12000
40+
return _py_major_ver >= 12 and _driver_ver >= 12040
4141

4242

4343
def test_kernel_attributes_init_disabled():
@@ -180,12 +180,15 @@ def test_object_code_handle(get_saxpy_object_code):
180180
assert mod.handle is not None
181181

182182

183-
def test_saxpy_arguments(get_saxpy_kernel, cuda12_prerequisite_check):
184-
if not cuda12_prerequisite_check:
185-
pytest.skip("Test requires CUDA 12")
183+
def test_saxpy_arguments(get_saxpy_kernel, cuda12_4_prerequisite_check):
186184
krn, _ = get_saxpy_kernel
187185

188-
assert krn.num_arguments == 5
186+
if cuda12_4_prerequisite_check:
187+
assert krn.num_arguments == 5
188+
else:
189+
with pytest.raises("NotImplementedError"):
190+
_ = krn.num_arguments
191+
return
189192

190193
assert "ParamInfo" in str(type(krn).arguments_info.fget.__annotations__)
191194
arg_info = krn.arguments_info
@@ -212,8 +215,8 @@ class ExpectedStruct(ctypes.Structure):
212215

213216
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
214217
@pytest.mark.parametrize("c_type_name,c_type", [("int", ctypes.c_int), ("short", ctypes.c_short)], ids=["int", "short"])
215-
def test_num_arguments(init_cuda, nargs, c_type_name, c_type, cuda12_prerequisite_check):
216-
if not cuda12_prerequisite_check:
218+
def test_num_arguments(init_cuda, nargs, c_type_name, c_type, cuda12_4_prerequisite_check):
219+
if not cuda12_4_prerequisite_check:
217220
pytest.skip("Test requires CUDA 12")
218221
args_str = ", ".join([f"{c_type_name} p_{i}" for i in range(nargs)])
219222
src = f"__global__ void foo{nargs}({args_str}) {{ }}"
@@ -235,8 +238,8 @@ class ExpectedStruct(ctypes.Structure):
235238
assert all([actual.size == expected.size for actual, expected in zip(arg_info, members)])
236239

237240

238-
def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisite_check):
239-
if not cuda12_prerequisite_check:
241+
def test_num_args_error_handling(deinit_all_contexts_function, cuda12_4_prerequisite_check):
242+
if not cuda12_4_prerequisite_check:
240243
pytest.skip("Test requires CUDA 12")
241244
src = "__global__ void foo(int a) { }"
242245
prog = Program(src, code_type="c++")

0 commit comments

Comments
 (0)