3232
3333
3434@pytest .fixture (scope = "module" )
35- def cuda12_prerequisite_check ():
35+ def cuda12_4_prerequisite_check ():
3636 # binding availability depends on cuda-python version
3737 # and version of underlying CUDA toolkit
3838 _py_major_ver , _ = get_binding_version ()
3939 _driver_ver = handle_return (driver .cuDriverGetVersion ())
40- return _py_major_ver >= 12 and _driver_ver >= 12000
40+ return _py_major_ver >= 12 and _driver_ver >= 12040
4141
4242
4343def test_kernel_attributes_init_disabled ():
@@ -180,12 +180,15 @@ def test_object_code_handle(get_saxpy_object_code):
180180 assert mod .handle is not None
181181
182182
183- def test_saxpy_arguments (get_saxpy_kernel , cuda12_prerequisite_check ):
184- if not cuda12_prerequisite_check :
185- pytest .skip ("Test requires CUDA 12" )
183+ def test_saxpy_arguments (get_saxpy_kernel , cuda12_4_prerequisite_check ):
186184 krn , _ = get_saxpy_kernel
187185
188- assert krn .num_arguments == 5
186+ if cuda12_4_prerequisite_check :
187+ assert krn .num_arguments == 5
188+ else :
189+ with pytest .raises (NotImplementedError ):
190+ _ = krn .num_arguments
191+ return
189192
190193 assert "ParamInfo" in str (type (krn ).arguments_info .fget .__annotations__ )
191194 arg_info = krn .arguments_info
@@ -212,8 +215,8 @@ class ExpectedStruct(ctypes.Structure):
212215
213216@pytest .mark .parametrize ("nargs" , [0 , 1 , 2 , 3 , 16 ])
214217@pytest .mark .parametrize ("c_type_name,c_type" , [("int" , ctypes .c_int ), ("short" , ctypes .c_short )], ids = ["int" , "short" ])
215- def test_num_arguments (init_cuda , nargs , c_type_name , c_type , cuda12_prerequisite_check ):
216- if not cuda12_prerequisite_check :
218+ def test_num_arguments (init_cuda , nargs , c_type_name , c_type , cuda12_4_prerequisite_check ):
219+ if not cuda12_4_prerequisite_check :
217220 pytest .skip ("Test requires CUDA 12" )
218221 args_str = ", " .join ([f"{ c_type_name } p_{ i } " for i in range (nargs )])
219222 src = f"__global__ void foo{ nargs } ({ args_str } ) {{ }}"
@@ -235,8 +238,8 @@ class ExpectedStruct(ctypes.Structure):
235238 assert all ([actual .size == expected .size for actual , expected in zip (arg_info , members )])
236239
237240
238- def test_num_args_error_handling (deinit_all_contexts_function , cuda12_prerequisite_check ):
239- if not cuda12_prerequisite_check :
241+ def test_num_args_error_handling (deinit_all_contexts_function , cuda12_4_prerequisite_check ):
242+ if not cuda12_4_prerequisite_check :
240243 pytest .skip ("Test requires CUDA 12" )
241244 src = "__global__ void foo(int a) { }"
242245 prog = Program (src , code_type = "c++" )
0 commit comments