-
Notifications
You must be signed in to change notification settings - Fork 55
feat: users can pass shared_memory_carveout to @cuda.jit
#642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9368f5d
81d16af
1697fc7
d91f8f4
683e774
17f23dc
cb34ba3
0bb051f
35ce525
373b7c2
e5ba748
1be8209
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,6 +30,7 @@ def jit( | |||||||||||||||||||||||||||||||||||||||||
| cache=False, | ||||||||||||||||||||||||||||||||||||||||||
| launch_bounds=None, | ||||||||||||||||||||||||||||||||||||||||||
| lto=None, | ||||||||||||||||||||||||||||||||||||||||||
| shared_memory_carveout=None, | ||||||||||||||||||||||||||||||||||||||||||
| **kws, | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -93,6 +94,19 @@ def jit( | |||||||||||||||||||||||||||||||||||||||||
| default when nvjitlink is available, except for kernels where | ||||||||||||||||||||||||||||||||||||||||||
| ``debug=True``. | ||||||||||||||||||||||||||||||||||||||||||
| :type lto: bool | ||||||||||||||||||||||||||||||||||||||||||
| :param shared_memory_carveout: Controls the partitioning of shared memory and L1 | ||||||||||||||||||||||||||||||||||||||||||
| cache on the GPU. Accepts either a string or an integer: | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| - String values: ``"MaxL1"`` (maximize L1 cache), ``"MaxShared"`` | ||||||||||||||||||||||||||||||||||||||||||
| (maximize shared memory), or ``"default"`` (use driver default). | ||||||||||||||||||||||||||||||||||||||||||
| - Integer values: 0-100 representing the percentage of shared | ||||||||||||||||||||||||||||||||||||||||||
| memory to carve out from the unified memory pool, or -1 for | ||||||||||||||||||||||||||||||||||||||||||
| the default carveout preference. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| This parameter is only effective on devices with a unified L1/shared memory | ||||||||||||||||||||||||||||||||||||||||||
| architecture. If unspecified, the CUDA driver uses the default carveout | ||||||||||||||||||||||||||||||||||||||||||
| preference. | ||||||||||||||||||||||||||||||||||||||||||
| :type shared_memory_carveout: str | int | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if link and config.ENABLE_CUDASIM: | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -111,6 +125,9 @@ def jit( | |||||||||||||||||||||||||||||||||||||||||
| msg = _msg_deprecated_signature_arg.format("bind") | ||||||||||||||||||||||||||||||||||||||||||
| raise DeprecationError(msg) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if shared_memory_carveout is not None: | ||||||||||||||||||||||||||||||||||||||||||
| _validate_shared_memory_carveout(shared_memory_carveout) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if isinstance(inline, bool): | ||||||||||||||||||||||||||||||||||||||||||
| DeprecationWarning( | ||||||||||||||||||||||||||||||||||||||||||
| "Passing bool to inline argument is deprecated, please refer to " | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -186,6 +203,7 @@ def _jit(func): | |||||||||||||||||||||||||||||||||||||||||
| targetoptions["extensions"] = extensions | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["launch_bounds"] = launch_bounds | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["lto"] = lto | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["shared_memory_carveout"] = shared_memory_carveout | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| disp = CUDADispatcher(func, targetoptions=targetoptions) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -234,6 +252,7 @@ def autojitwrapper(func): | |||||||||||||||||||||||||||||||||||||||||
| link=link, | ||||||||||||||||||||||||||||||||||||||||||
| cache=cache, | ||||||||||||||||||||||||||||||||||||||||||
| launch_bounds=launch_bounds, | ||||||||||||||||||||||||||||||||||||||||||
| shared_memory_carveout=shared_memory_carveout, | ||||||||||||||||||||||||||||||||||||||||||
| **kws, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -257,6 +276,7 @@ def autojitwrapper(func): | |||||||||||||||||||||||||||||||||||||||||
| targetoptions["extensions"] = extensions | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["launch_bounds"] = launch_bounds | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["lto"] = lto | ||||||||||||||||||||||||||||||||||||||||||
| targetoptions["shared_memory_carveout"] = shared_memory_carveout | ||||||||||||||||||||||||||||||||||||||||||
| disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if cache: | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -292,3 +312,20 @@ def declare_device(name, sig, link=None, use_cooperative=False): | |||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return template.key | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _validate_shared_memory_carveout(carveout): | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(carveout, str): | ||||||||||||||||||||||||||||||||||||||||||
| valid_strings = ["default", "maxl1", "maxshared"] | ||||||||||||||||||||||||||||||||||||||||||
| if carveout.lower() not in valid_strings: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||
| f"Invalid carveout value: {carveout}. " | ||||||||||||||||||||||||||||||||||||||||||
| f"Must be -1 to 100 or one of {valid_strings}" | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+319
to
+323
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message displays lowercase strings The validation correctly accepts any case (due to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+319
to
+324
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message displays lowercase string options Consider updating the error message to show the capitalized forms that users should actually use:
Suggested change
This way the error message displays the same capitalization as the documentation, while still accepting case-insensitive input. |
||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(carveout, int): | ||||||||||||||||||||||||||||||||||||||||||
| if not (-1 <= carveout <= 100): | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Carveout must be between -1 and 100") | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+325
to
+327
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Consider explicitly rejecting boolean types:
Suggested change
Note: The
Comment on lines
+325
to
+327
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The validation accepts boolean values because
To fix this, check for boolean types explicitly before checking for int:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||||||||||||||||||||||
| f"shared_memory_carveout must be str or int, got {type(carveout).__name__}" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -816,5 +816,67 @@ def test_too_many_launch_bounds(self): | |
| cuda.jit("void()", launch_bounds=launch_bounds)(lambda: None) | ||
|
|
||
|
|
||
| @skip_on_cudasim("Simulator does not support shared memory carveout") | ||
| class TestSharedMemoryCarveout(CUDATestCase): | ||
| def test_shared_memory_carveout_invalid_values(self): | ||
| """Test that invalid carveout values raise appropriate errors""" | ||
| test_cases = [ | ||
| (150, ValueError, "must be between -1 and 100"), | ||
| (-2, ValueError, "must be between -1 and 100"), | ||
| (101, ValueError, "must be between -1 and 100"), | ||
| ("InvalidOption", ValueError, "Invalid carveout value"), | ||
| ] | ||
|
|
||
| for carveout, exc_type, msg_pattern in test_cases: | ||
| with self.subTest(carveout=carveout): | ||
| # without signature | ||
| with self.assertRaisesRegex(exc_type, msg_pattern): | ||
|
Comment on lines
+821
to
+833
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test coverage for invalid values should include boolean types, which are currently incorrectly accepted due to the validation bug in decorators.py (line 325). Consider adding test cases for:
Example addition to test_cases: test_cases = [
(150, ValueError, "must be between -1 and 100"),
(-2, ValueError, "must be between -1 and 100"),
(101, ValueError, "must be between -1 and 100"),
("InvalidOption", ValueError, "Invalid carveout value"),
(True, TypeError, "must be str or int"), # Currently fails - booleans are accepted
(False, TypeError, "must be str or int"), # Currently fails - booleans are accepted
(50.5, TypeError, "must be str or int"),
] |
||
|
|
||
| @cuda.jit(shared_memory_carveout=carveout) | ||
| def add_one(x): | ||
| i = cuda.grid(1) | ||
| if i < len(x): | ||
| x[i] = i + 1 | ||
|
|
||
| # with signature | ||
| with self.assertRaisesRegex(exc_type, msg_pattern): | ||
|
|
||
| @cuda.jit("void(int32[:])", shared_memory_carveout=carveout) | ||
| def add_one_sig(x): | ||
| i = cuda.grid(1) | ||
| if i < len(x): | ||
| x[i] = i + 1 | ||
|
|
||
| def test_shared_memory_carveout_valid_values(self): | ||
| carveout_values = ["MaxL1", "MaxShared", "default", 0, 50, 100, -1] | ||
|
|
||
| x = np.zeros(10, dtype=np.int32) | ||
| expected = np.arange(1, 11) | ||
|
|
||
| for carveout in carveout_values: | ||
| with self.subTest(carveout=carveout): | ||
| # without signature | ||
| @cuda.jit(shared_memory_carveout=carveout) | ||
| def add_one(x): | ||
| i = cuda.grid(1) | ||
| if i < x.size: | ||
| x[i] = i + 1 | ||
|
|
||
| d_x = cuda.to_device(x) | ||
| add_one[1, 10](d_x) | ||
| np.testing.assert_array_equal(d_x.copy_to_host(), expected) | ||
|
|
||
| # with signature | ||
| @cuda.jit("void(int32[:])", shared_memory_carveout=carveout) | ||
| def add_one_sig(x): | ||
| i = cuda.grid(1) | ||
| if i < x.size: | ||
| x[i] = i + 1 | ||
|
|
||
| d_x = cuda.to_device(x) | ||
| add_one_sig[1, 10](d_x) | ||
| np.testing.assert_array_equal(d_x.copy_to_host(), expected) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
shared_memory_carveoutparameter is validated here but doesn't check ifdevice=Trueis also set. Since shared memory carveout only applies to kernel launches (not device functions), using this parameter withdevice=Trueshould raise an error rather than being silently ignored.Consider adding validation:
This provides clearer user feedback when the parameter is misused.