diff --git a/numba_cuda/numba/cuda/decorators.py b/numba_cuda/numba/cuda/decorators.py
index 8d8b90817..978fae4f6 100644
--- a/numba_cuda/numba/cuda/decorators.py
+++ b/numba_cuda/numba/cuda/decorators.py
@@ -30,6 +30,7 @@ def jit(
     cache=False,
     launch_bounds=None,
     lto=None,
+    shared_memory_carveout=None,
     **kws,
 ):
     """
@@ -93,6 +94,19 @@ def jit(
        default when nvjitlink is available, except for kernels where
        ``debug=True``.
     :type lto: bool
+    :param shared_memory_carveout: Controls the partitioning of shared memory
+       and L1 cache on the GPU. Accepts either a string or an integer:
+
+       - String values: ``"MaxL1"`` (maximize L1 cache), ``"MaxShared"``
+         (maximize shared memory), or ``"default"`` (use the driver default).
+       - Integer values: 0-100, the percentage of the unified L1/shared
+         memory pool to devote to shared memory, or -1 for the default
+         carveout preference.
+
+       This parameter is only effective on devices with a unified L1/shared
+       memory architecture. If unspecified, the CUDA driver uses its default
+       carveout preference.
+    :type shared_memory_carveout: str | int
     """
 
     if link and config.ENABLE_CUDASIM:
@@ -111,6 +125,9 @@ def jit(
         msg = _msg_deprecated_signature_arg.format("bind")
         raise DeprecationError(msg)
 
+    if shared_memory_carveout is not None:
+        _validate_shared_memory_carveout(shared_memory_carveout)
+
     if isinstance(inline, bool):
         DeprecationWarning(
             "Passing bool to inline argument is deprecated, please refer to "
@@ -186,6 +203,7 @@ def _jit(func):
         targetoptions["extensions"] = extensions
         targetoptions["launch_bounds"] = launch_bounds
         targetoptions["lto"] = lto
+        targetoptions["shared_memory_carveout"] = shared_memory_carveout
 
         disp = CUDADispatcher(func, targetoptions=targetoptions)
 
@@ -234,6 +252,7 @@ def autojitwrapper(func):
                 link=link,
                 cache=cache,
                 launch_bounds=launch_bounds,
+                shared_memory_carveout=shared_memory_carveout,
                 **kws,
             )
 
@@ -257,6 +276,7 @@ def autojitwrapper(func):
         targetoptions["extensions"] = extensions
         targetoptions["launch_bounds"] = launch_bounds
         targetoptions["lto"] = lto
+        targetoptions["shared_memory_carveout"] = shared_memory_carveout
         disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)
 
         if cache:
@@ -292,3 +312,20 @@ def declare_device(name, sig, link=None, use_cooperative=False):
     )
 
     return template.key
+
+
+def _validate_shared_memory_carveout(carveout):
+    if isinstance(carveout, str):
+        valid_strings = ["default", "maxl1", "maxshared"]
+        if carveout.lower() not in valid_strings:
+            raise ValueError(
+                f"Invalid carveout value: {carveout}. "
+                f"Must be -1 to 100 or one of {valid_strings}"
+            )
+    elif isinstance(carveout, int):
+        if not (-1 <= carveout <= 100):
+            raise ValueError("Carveout must be between -1 and 100")
+    else:
+        raise TypeError(
+            f"shared_memory_carveout must be str or int, got {type(carveout).__name__}"
+        )
diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index 776b8e0b2..acbe13575 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -116,6 +116,7 @@ def __init__(
         opt=True,
         device=False,
         launch_bounds=None,
+        shared_memory_carveout=None,
     ):
         if device:
             raise RuntimeError("Cannot compile a device function as a kernel")
@@ -170,6 +171,12 @@ def __init__(
         lib._entry_name = cres.fndesc.llvm_func_name
         kernel_fixup(kernel, self.debug)
         nvvm.set_launch_bounds(kernel, launch_bounds)
+        if shared_memory_carveout is not None:
+            self.shared_memory_carveout = self._parse_carveout(
+                shared_memory_carveout
+            )
+        else:
+            self.shared_memory_carveout = None
 
         if not link:
             link = []
@@ -289,6 +296,7 @@ def _rebuild(
         lineinfo,
         call_helper,
         extensions,
+        shared_memory_carveout=None,
     ):
         """
         Rebuild an instance.
@@ -307,6 +315,7 @@ def _rebuild(
         instance.lineinfo = lineinfo
         instance.call_helper = call_helper
         instance.extensions = extensions
+        instance.shared_memory_carveout = shared_memory_carveout
         return instance
 
     def _reduce_states(self):
@@ -326,8 +335,15 @@ def _reduce_states(self):
             lineinfo=self.lineinfo,
             call_helper=self.call_helper,
             extensions=self.extensions,
+            shared_memory_carveout=self.shared_memory_carveout,
         )
 
+    def _parse_carveout(self, carveout):
+        if isinstance(carveout, int):
+            return carveout
+        carveout_map = {"default": -1, "maxl1": 0, "maxshared": 100}
+        return carveout_map[str(carveout).lower()]
+
     @module_init_lock
     def initialize_once(self, mod):
         if not mod.initialized:
@@ -341,6 +357,9 @@ def bind(self):
 
         self.initialize_once(cufunc.module)
 
+        if self.shared_memory_carveout is not None:
+            cufunc.set_shared_memory_carveout(self.shared_memory_carveout)
+
         if (
             hasattr(self, "target_context")
             and self.target_context.enable_nrt
diff --git a/numba_cuda/numba/cuda/simulator/api.py b/numba_cuda/numba/cuda/simulator/api.py
index 7f98aebf7..eb085cb5d 100644
--- a/numba_cuda/numba/cuda/simulator/api.py
+++ b/numba_cuda/numba/cuda/simulator/api.py
@@ -129,6 +129,7 @@ def jit(
     boundscheck=None,
     opt=None,
     cache=None,
+    shared_memory_carveout=None,
 ):
     # Here for API compatibility
     if boundscheck:
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
index 851b4683b..aa837dc1b 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
@@ -816,5 +816,67 @@ def test_too_many_launch_bounds(self):
         cuda.jit("void()", launch_bounds=launch_bounds)(lambda: None)
 
 
+@skip_on_cudasim("Simulator does not support shared memory carveout")
+class TestSharedMemoryCarveout(CUDATestCase):
+    def test_shared_memory_carveout_invalid_values(self):
+        """Test that invalid carveout values raise appropriate errors"""
+        test_cases = [
+            (150, ValueError, "must be between -1 and 100"),
+            (-2, ValueError, "must be between -1 and 100"),
+            (101, ValueError, "must be between -1 and 100"),
+            ("InvalidOption", ValueError, "Invalid carveout value"),
+        ]
+
+        for carveout, exc_type, msg_pattern in test_cases:
+            with self.subTest(carveout=carveout):
+                # without signature
+                with self.assertRaisesRegex(exc_type, msg_pattern):
+
+                    @cuda.jit(shared_memory_carveout=carveout)
+                    def add_one(x):
+                        i = cuda.grid(1)
+                        if i < len(x):
+                            x[i] = i + 1
+
+                # with signature
+                with self.assertRaisesRegex(exc_type, msg_pattern):
+
+                    @cuda.jit("void(int32[:])", shared_memory_carveout=carveout)
+                    def add_one_sig(x):
+                        i = cuda.grid(1)
+                        if i < len(x):
+                            x[i] = i + 1
+
+    def test_shared_memory_carveout_valid_values(self):
+        carveout_values = ["MaxL1", "MaxShared", "default", 0, 50, 100, -1]
+
+        x = np.zeros(10, dtype=np.int32)
+        expected = np.arange(1, 11)
+
+        for carveout in carveout_values:
+            with self.subTest(carveout=carveout):
+                # without signature
+                @cuda.jit(shared_memory_carveout=carveout)
+                def add_one(x):
+                    i = cuda.grid(1)
+                    if i < x.size:
+                        x[i] = i + 1
+
+                d_x = cuda.to_device(x)
+                add_one[1, 10](d_x)
+                np.testing.assert_array_equal(d_x.copy_to_host(), expected)
+
+                # with signature
+                @cuda.jit("void(int32[:])", shared_memory_carveout=carveout)
+                def add_one_sig(x):
+                    i = cuda.grid(1)
+                    if i < x.size:
+                        x[i] = i + 1
+
+                d_x = cuda.to_device(x)
+                add_one_sig[1, 10](d_x)
+                np.testing.assert_array_equal(d_x.copy_to_host(), expected)
+
+
 if __name__ == "__main__":
     unittest.main()
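
Usage sketch: the kernel name ``scale`` and the data below are illustrative; only the
``shared_memory_carveout`` argument and its accepted values ("MaxL1", "MaxShared",
"default", or an integer from -1 to 100) come from the change above. The string
"MaxShared" is parsed to an integer carveout of 100 and the preference is applied
when the kernel is bound to a context.

    import numpy as np
    from numba import cuda

    # Prefer shared memory over L1 for this kernel on devices with a unified
    # L1/shared memory architecture; equivalent to shared_memory_carveout=100.
    @cuda.jit(shared_memory_carveout="MaxShared")
    def scale(x, factor):
        i = cuda.grid(1)
        if i < x.size:
            x[i] *= factor

    d_x = cuda.to_device(np.arange(256, dtype=np.float32))
    scale[4, 64](d_x, 2.0)  # carveout preference is set when the kernel is bound
    print(d_x.copy_to_host()[:4])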