37 changes: 37 additions & 0 deletions numba_cuda/numba/cuda/decorators.py
@@ -30,6 +30,7 @@ def jit(
cache=False,
launch_bounds=None,
lto=None,
shared_memory_carveout=None,
**kws,
):
"""
@@ -93,6 +94,19 @@ def jit(
default when nvjitlink is available, except for kernels where
``debug=True``.
:type lto: bool
:param shared_memory_carveout: Controls the partitioning of shared memory and L1
cache on the GPU. Accepts either a string or an integer:

- String values: ``"MaxL1"`` (maximize L1 cache), ``"MaxShared"``
(maximize shared memory), or ``"default"`` (use driver default).
- Integer values: 0-100 representing the percentage of shared
memory to carve out from the unified memory pool, or -1 for
the default carveout preference.

This parameter is only effective on devices with a unified L1/shared memory
architecture. If unspecified, the CUDA driver uses the default carveout
preference.
:type shared_memory_carveout: str | int
"""

if link and config.ENABLE_CUDASIM:
@@ -111,6 +125,9 @@ def jit(
msg = _msg_deprecated_signature_arg.format("bind")
raise DeprecationError(msg)

if shared_memory_carveout is not None:
_validate_shared_memory_carveout(shared_memory_carveout)
Comment on lines +128 to +129
The shared_memory_carveout parameter is validated here but doesn't check if device=True is also set. Since shared memory carveout only applies to kernel launches (not device functions), using this parameter with device=True should raise an error rather than being silently ignored.

Consider adding validation:

Suggested change:

    # Current:
    if shared_memory_carveout is not None:
        _validate_shared_memory_carveout(shared_memory_carveout)

    # Suggested:
    if shared_memory_carveout is not None:
        if device:
            raise ValueError(
                "shared_memory_carveout cannot be used with device=True. "
                "This parameter only applies to kernels, not device functions."
            )
        _validate_shared_memory_carveout(shared_memory_carveout)

This provides clearer user feedback when the parameter is misused.


if isinstance(inline, bool):
DeprecationWarning(
"Passing bool to inline argument is deprecated, please refer to "
@@ -186,6 +203,7 @@ def _jit(func):
targetoptions["extensions"] = extensions
targetoptions["launch_bounds"] = launch_bounds
targetoptions["lto"] = lto
targetoptions["shared_memory_carveout"] = shared_memory_carveout

disp = CUDADispatcher(func, targetoptions=targetoptions)

@@ -234,6 +252,7 @@ def autojitwrapper(func):
link=link,
cache=cache,
launch_bounds=launch_bounds,
shared_memory_carveout=shared_memory_carveout,
**kws,
)

@@ -257,6 +276,7 @@ def autojitwrapper(func):
targetoptions["extensions"] = extensions
targetoptions["launch_bounds"] = launch_bounds
targetoptions["lto"] = lto
targetoptions["shared_memory_carveout"] = shared_memory_carveout
disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)

if cache:
@@ -292,3 +312,20 @@ def declare_device(name, sig, link=None, use_cooperative=False):
)

return template.key


def _validate_shared_memory_carveout(carveout):
if isinstance(carveout, str):
valid_strings = ["default", "maxl1", "maxshared"]
if carveout.lower() not in valid_strings:
raise ValueError(
f"Invalid carveout value: {carveout}. "
f"Must be -1 to 100 or one of {valid_strings}"
Comment on lines +319 to +323
The error message displays lowercase strings ["default", "maxl1", "maxshared"], but the documentation (lines 100-101) shows capitalized forms "MaxL1", "MaxShared", "default". This inconsistency will confuse users when they see an error.

The validation correctly accepts any case (due to carveout.lower() on line 320), but the error message should match what's documented.

Suggested change:

    # Current:
    valid_strings = ["default", "maxl1", "maxshared"]
    if carveout.lower() not in valid_strings:
        raise ValueError(
            f"Invalid carveout value: {carveout}. "
            f"Must be -1 to 100 or one of {valid_strings}"

    # Suggested:
    valid_strings = ["default", "maxl1", "maxshared"]
    if carveout.lower() not in valid_strings:
        raise ValueError(
            f"Invalid carveout value: {carveout}. "
            f"Must be -1 to 100 or one of ['default', 'MaxL1', 'MaxShared']"
        )

)
Comment on lines +319 to +324
The error message displays lowercase string options ["default", "maxl1", "maxshared"], but the documentation (lines 100-101) and tests (test_dispatcher.py line 851) use the capitalized forms "MaxL1", "MaxShared", "default". This inconsistency will confuse users who follow the documentation.

Consider updating the error message to show the capitalized forms that users should actually use:

Suggested change:

    # Current:
    valid_strings = ["default", "maxl1", "maxshared"]
    if carveout.lower() not in valid_strings:
        raise ValueError(
            f"Invalid carveout value: {carveout}. "
            f"Must be -1 to 100 or one of {valid_strings}"
        )

    # Suggested:
    valid_strings = ["default", "MaxL1", "MaxShared"]
    if carveout.lower() not in [s.lower() for s in valid_strings]:
        raise ValueError(
            f"Invalid carveout value: {carveout}. "
            f"Must be -1 to 100 or one of {valid_strings}"
        )

This way the error message displays the same capitalization as the documentation, while still accepting case-insensitive input.

elif isinstance(carveout, int):
if not (-1 <= carveout <= 100):
raise ValueError("Carveout must be between -1 and 100")
Comment on lines +325 to +327
The isinstance(carveout, int) check will also accept boolean values since bool is a subclass of int in Python. This means shared_memory_carveout=True would be treated as 1 and shared_memory_carveout=False as 0, which is likely unintended behavior.

Consider explicitly rejecting boolean types:

Suggested change:

    # Current:
    elif isinstance(carveout, int):
        if not (-1 <= carveout <= 100):
            raise ValueError("Carveout must be between -1 and 100")

    # Suggested:
    elif isinstance(carveout, bool):
        raise TypeError(
            f"shared_memory_carveout must be str or int, got {type(carveout).__name__}"
        )
    elif isinstance(carveout, int):
        if not (-1 <= carveout <= 100):
            raise ValueError("Carveout must be between -1 and 100")

Note: The bool check must come before the int check since bool is a subclass of int.
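For reference, a quick illustration of why the ordering matters (plain Python, independent of this diff):

    isinstance(True, int)   # True -- bool is a subclass of int
    isinstance(True, bool)  # True
    # An isinstance(x, int) branch therefore also matches True/False,
    # so the bool check must come first for booleans to be rejected.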

else:
raise TypeError(
f"shared_memory_carveout must be str or int, got {type(carveout).__name__}"
)
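
For context, a minimal usage sketch of the new option as documented above; the kernel below is hypothetical and only illustrates how the argument is passed (string form shown, but an integer such as 75, or -1 for the driver default, is also accepted):

    from numba import cuda
    import numpy as np

    # Prefer shared memory over L1 for this kernel; string values are case-insensitive
    @cuda.jit(shared_memory_carveout="MaxShared")
    def scale(x):
        i = cuda.grid(1)
        if i < x.size:
            x[i] *= 2

    arr = cuda.to_device(np.arange(16, dtype=np.float32))
    scale[1, 16](arr)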
19 changes: 19 additions & 0 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -116,6 +116,7 @@ def __init__(
opt=True,
device=False,
launch_bounds=None,
shared_memory_carveout=None,
):
if device:
raise RuntimeError("Cannot compile a device function as a kernel")
@@ -170,6 +171,12 @@ def __init__(
lib._entry_name = cres.fndesc.llvm_func_name
kernel_fixup(kernel, self.debug)
nvvm.set_launch_bounds(kernel, launch_bounds)
if shared_memory_carveout is not None:
self.shared_memory_carveout = self._parse_carveout(
shared_memory_carveout
)
else:
self.shared_memory_carveout = None

if not link:
link = []
@@ -289,6 +296,7 @@ def _rebuild(
lineinfo,
call_helper,
extensions,
shared_memory_carveout=None,
):
"""
Rebuild an instance.
@@ -307,6 +315,7 @@ def _rebuild(
instance.lineinfo = lineinfo
instance.call_helper = call_helper
instance.extensions = extensions
instance.shared_memory_carveout = shared_memory_carveout
return instance

def _reduce_states(self):
@@ -326,8 +335,15 @@ def _reduce_states(self):
lineinfo=self.lineinfo,
call_helper=self.call_helper,
extensions=self.extensions,
shared_memory_carveout=self.shared_memory_carveout,
)

def _parse_carveout(self, carveout):
if isinstance(carveout, int):
return carveout
carveout_map = {"default": -1, "maxl1": 0, "maxshared": 100}
return carveout_map[str(carveout).lower()]

@module_init_lock
def initialize_once(self, mod):
if not mod.initialized:
@@ -341,6 +357,9 @@ def bind(self):

self.initialize_once(cufunc.module)

if self.shared_memory_carveout is not None:
cufunc.set_shared_memory_carveout(self.shared_memory_carveout)

if (
hasattr(self, "target_context")
and self.target_context.enable_nrt
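For background, the carveout value stored here presumably ends up on the function handle via the CUDA driver's preferred-carveout attribute. A rough, hypothetical sketch of that mapping using the cuda-python bindings (not the code in this PR; numba-cuda's set_shared_memory_carveout may be implemented differently):

    from cuda import cuda as drv

    def apply_carveout(cufunc_handle, percent):
        # CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT accepts -1 (driver
        # default) or 0-100 as the preferred shared-memory percentage.
        (err,) = drv.cuFuncSetAttribute(
            cufunc_handle,
            drv.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT,
            percent,
        )
        if err != drv.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"cuFuncSetAttribute failed: {err}")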
1 change: 1 addition & 0 deletions numba_cuda/numba/cuda/simulator/api.py
@@ -129,6 +129,7 @@ def jit(
boundscheck=None,
opt=None,
cache=None,
shared_memory_carveout=None,
):
# Here for API compatibility
if boundscheck:
62 changes: 62 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py
@@ -816,5 +816,67 @@ def test_too_many_launch_bounds(self):
cuda.jit("void()", launch_bounds=launch_bounds)(lambda: None)


@skip_on_cudasim("Simulator does not support shared memory carveout")
class TestSharedMemoryCarveout(CUDATestCase):
def test_shared_memory_carveout_invalid_values(self):
"""Test that invalid carveout values raise appropriate errors"""
test_cases = [
(150, ValueError, "must be between -1 and 100"),
(-2, ValueError, "must be between -1 and 100"),
(101, ValueError, "must be between -1 and 100"),
("InvalidOption", ValueError, "Invalid carveout value"),
]

for carveout, exc_type, msg_pattern in test_cases:
with self.subTest(carveout=carveout):
# without signature
with self.assertRaisesRegex(exc_type, msg_pattern):
Comment on lines +821 to +833
The test coverage for invalid values should include boolean types, which are currently incorrectly accepted due to the validation bug in decorators.py (line 325).

Consider adding test cases for:

  • Boolean values: True, False (currently accepted but shouldn't be)
  • Float values: 50.5, 0.0 (should be rejected with TypeError)
  • Other invalid types: [], {}, () (should be rejected with TypeError)

Example addition to test_cases:

test_cases = [
    (150, ValueError, "must be between -1 and 100"),
    (-2, ValueError, "must be between -1 and 100"),
    (101, ValueError, "must be between -1 and 100"),
    ("InvalidOption", ValueError, "Invalid carveout value"),
    (True, TypeError, "must be str or int"),  # Currently fails - booleans are accepted
    (False, TypeError, "must be str or int"),  # Currently fails - booleans are accepted
    (50.5, TypeError, "must be str or int"),
]


@cuda.jit(shared_memory_carveout=carveout)
def add_one(x):
i = cuda.grid(1)
if i < len(x):
x[i] = i + 1

# with signature
with self.assertRaisesRegex(exc_type, msg_pattern):

@cuda.jit("void(int32[:])", shared_memory_carveout=carveout)
def add_one_sig(x):
i = cuda.grid(1)
if i < len(x):
x[i] = i + 1

def test_shared_memory_carveout_valid_values(self):
carveout_values = ["MaxL1", "MaxShared", "default", 0, 50, 100, -1]

x = np.zeros(10, dtype=np.int32)
expected = np.arange(1, 11)

for carveout in carveout_values:
with self.subTest(carveout=carveout):
# without signature
@cuda.jit(shared_memory_carveout=carveout)
def add_one(x):
i = cuda.grid(1)
if i < x.size:
x[i] = i + 1

d_x = cuda.to_device(x)
add_one[1, 10](d_x)
np.testing.assert_array_equal(d_x.copy_to_host(), expected)

# with signature
@cuda.jit("void(int32[:])", shared_memory_carveout=carveout)
def add_one_sig(x):
i = cuda.grid(1)
if i < x.size:
x[i] = i + 1

d_x = cuda.to_device(x)
add_one_sig[1, 10](d_x)
np.testing.assert_array_equal(d_x.copy_to_host(), expected)


if __name__ == "__main__":
unittest.main()