feat: add support for cache-hinted load and store operations #587
Merged
Changes from all commits (16 commits):
cecd6ea  chore: move work from gmarkall branch (kaeun97)
a2b7127  fix: add license and nits (kaeun97)
9963469  feat: validate bitwidth (kaeun97)
9b5c51a  chore: extract common components (kaeun97)
671bcf3  feat: support cpointer (kaeun97)
06a1c35  feat: add test for cpointers (kaeun97)
f4da9db  fix: add test for bitwidth (kaeun97)
94ec9a3  fix: add test for bitwidth (kaeun97)
393554a  chore: remove comment (kaeun97)
b53cc78  feat: add initial docs (kaeun97)
351955e  fix: docs (kaeun97)
6eafe13  fix: better and cleaner test (kaeun97)
aa5302d  chore: pre-commit (kaeun97)
10e7a61  feat: add 2d array test (kaeun97)
7f42e38  fix: docs and remove complex types (kaeun97)
fe9e8ac  fix: ci simulator failure (kaeun97)
The first new file (@@ -0,0 +1,63 @@) is the documentation page for the cache-hint functions:

```rst
..
   SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
   SPDX-License-Identifier: BSD-2-Clause

.. _cache-hints:

Cache Hints for Memory Operations
=================================

These functions provide explicit control over caching behavior for memory
operations. They generate PTX instructions with cache policy hints that can
optimize specific memory access patterns. All functions support arrays or
pointers with all bitwidths of signed/unsigned integer and floating-point
types.

.. seealso:: `Cache Operators
   <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_
   in the PTX ISA documentation.

.. function:: numba.cuda.ldca(array, i)

   Load element ``i`` from ``array`` with cache-all policy (``ld.global.ca``).
   This is the default caching behavior.

.. function:: numba.cuda.ldcg(array, i)

   Load element ``i`` from ``array`` with cache-global policy (``ld.global.cg``).
   Useful for data shared across thread blocks.

.. function:: numba.cuda.ldcs(array, i)

   Load element ``i`` from ``array`` with cache-streaming policy
   (``ld.global.cs``). Optimized for streaming data accessed once.

.. function:: numba.cuda.ldlu(array, i)

   Load element ``i`` from ``array`` with last-use policy (``ld.global.lu``).
   Indicates data is unlikely to be reused.

.. function:: numba.cuda.ldcv(array, i)

   Load element ``i`` from ``array`` with cache-volatile policy (``ld.global.cv``).
   Used for volatile data that may change externally.

.. function:: numba.cuda.stcg(array, i, value)

   Store ``value`` to ``array[i]`` with cache-global policy (``st.global.cg``).
   Useful for data shared across thread blocks.

.. function:: numba.cuda.stcs(array, i, value)

   Store ``value`` to ``array[i]`` with cache-streaming policy (``st.global.cs``).
   Optimized for streaming writes.

.. function:: numba.cuda.stwb(array, i, value)

   Store ``value`` to ``array[i]`` with write-back policy (``st.global.wb``).
   This is the default caching behavior.

.. function:: numba.cuda.stwt(array, i, value)

   Store ``value`` to ``array[i]`` with write-through policy (``st.global.wt``).
   Writes through cache hierarchy to memory.
```
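To make the intent concrete, here is a minimal usage sketch. It is not taken from the PR; the kernel and array names are invented, and it assumes the functions are exposed as ``numba.cuda.ldcg``/``numba.cuda.stcg`` as the directives above document:

```python
import numpy as np
from numba import cuda

@cuda.jit
def scale_cg(dst, src, factor):
    i = cuda.grid(1)
    if i < src.size:
        # Load with ld.global.cg and store with st.global.cg instead of the
        # default ca / wb policies.
        cuda.stcg(dst, i, cuda.ldcg(src, i) * factor)

src = np.arange(1024, dtype=np.float32)
dst = np.zeros_like(src)
scale_cg[8, 128](dst, src, np.float32(2.0))
```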
The reference documentation index gains an entry for the new page (@@ -9,6 +9,7 @@ Reference documentation):

```diff
    host.rst
    kernel.rst
+   cache_hints.rst
    types.rst
    memory.rst
    libdevice.rst
```
The implementation is a new module (@@ -0,0 +1,287 @@). It opens with the license header, imports, the user-facing stub functions, and the register-constraint map used for the inline assembly:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from llvmlite import ir
from numba import types
from numba.cuda import cgutils
from numba.cuda.extending import intrinsic, overload
from numba.cuda.core.errors import NumbaTypeError
from numba.cuda.api_util import normalize_indices

# Docs references:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints


def ldca(array, i):
    """Generate a `ld.global.ca` instruction for element `i` of an array."""


def ldcg(array, i):
    """Generate a `ld.global.cg` instruction for element `i` of an array."""


def ldcs(array, i):
    """Generate a `ld.global.cs` instruction for element `i` of an array."""


def ldlu(array, i):
    """Generate a `ld.global.lu` instruction for element `i` of an array."""


def ldcv(array, i):
    """Generate a `ld.global.cv` instruction for element `i` of an array."""


def stcg(array, i, value):
    """Generate a `st.global.cg` instruction for element `i` of an array."""


def stcs(array, i, value):
    """Generate a `st.global.cs` instruction for element `i` of an array."""


def stwb(array, i, value):
    """Generate a `st.global.wb` instruction for element `i` of an array."""


def stwt(array, i, value):
    """Generate a `st.global.wt` instruction for element `i` of an array."""


# See
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
# for background on the choice of "r" for 8-bit operands - there is
# no constraint for 8-bit operands, but the operand for loads and
# stores is permitted to be greater than 8 bits.
CONSTRAINT_MAP = {1: "b", 8: "r", 16: "h", 32: "r", 64: "l", 128: "q"}
```
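As a small aside (not part of the diff), the map can be read off with Numba's type objects; note that 8-bit operands share the 32-bit "r" register class, which is what the sub-word-sizes comment is getting at:

```python
# Standalone sketch: CONSTRAINT_MAP is copied from the module above.
from numba import types

CONSTRAINT_MAP = {1: "b", 8: "r", 16: "h", 32: "r", 64: "l", 128: "q"}

for t in (types.uint8, types.int16, types.float32, types.float64):
    print(f"{t}: bitwidth {t.bitwidth} -> constraint '{CONSTRAINT_MAP[t.bitwidth]}'")
```

Next come the helpers that validate the array, index, and dtype arguments, and that compute the element pointer: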
```python
def _validate_arguments(instruction, array, index):
    is_array = isinstance(array, types.Array)
    is_pointer = isinstance(array, types.CPointer)
    if not (is_array or is_pointer):
        msg = f"{instruction} operates on arrays or pointers. Got type {array}"
        raise NumbaTypeError(msg)

    valid_index = False

    if isinstance(index, types.Integer):
        if is_array and array.ndim != 1:
            # for pointers, any integer index is valid
            msg = f"Expected {array.ndim} indices, got a scalar"
            raise NumbaTypeError(msg)
        valid_index = True

    if isinstance(index, types.UniTuple):
        if is_pointer:
            msg = f"Pointers only support scalar indexing, got tuple of {index.count}"
            raise NumbaTypeError(msg)

        if index.count != array.ndim:
            msg = f"Expected {array.ndim} indices, got {index.count}"
            raise NumbaTypeError(msg)

        if isinstance(index.dtype, types.Integer):
            valid_index = True

    if not valid_index:
        raise NumbaTypeError(f"{index} is not a valid index")


def _validate_bitwidth(instruction, array):
    dtype = array.dtype

    if not isinstance(dtype, (types.Integer, types.Float)):
        msg = (
            f"{instruction} requires array of integer or float type, "
            f"got {dtype}"
        )
        raise NumbaTypeError(msg)

    bitwidth = dtype.bitwidth
    if bitwidth not in CONSTRAINT_MAP:
        valid_widths = sorted(CONSTRAINT_MAP.keys())
        msg = (
            f"{instruction} requires array dtype with bitwidth "
            f"in {valid_widths}, got bitwidth {bitwidth}"
        )
        raise NumbaTypeError(msg)


def _get_element_pointer(
    context, builder, index_type, indices, array_type, array
):
    if isinstance(array_type, types.CPointer):
        return builder.gep(array, [indices])
    else:
        index_type, indices = normalize_indices(
            context,
            builder,
            index_type,
            indices,
            array_type,
            array_type.dtype,
        )
        array_struct = context.make_array(array_type)(
            context, builder, value=array
        )
        return cgutils.get_item_pointer(
            context,
            builder,
            array_type,
            array_struct,
            indices,
            wraparound=True,
        )
```
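Reading these checks together with the "add 2d array test" commit: integer indices are for 1-D arrays and pointers, while multi-dimensional arrays take a tuple with one integer per dimension. A hypothetical 2-D sketch (names invented, assuming ``numba.cuda.ldcs`` is exposed as documented):

```python
from numba import cuda

@cuda.jit
def transpose_cs(dst, src):
    i, j = cuda.grid(2)
    if i < src.shape[0] and j < src.shape[1]:
        # Tuple index: one integer per array dimension, as enforced by
        # _validate_arguments above.
        dst[j, i] = cuda.ldcs(src, (i, j))
```

The factory that builds the load intrinsics follows: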
```python
def ld_cache_operator(operator):
    @intrinsic
    def impl(typingctx, array, index):
        _validate_arguments(f"ld{operator}", array, index)
        _validate_bitwidth(f"ld{operator}", array)

        signature = array.dtype(array, index)

        def codegen(context, builder, sig, args):
            array_type, index_type = sig.args
            loaded_type = context.get_value_type(array_type.dtype)
            ptr_type = loaded_type.as_pointer()
            ldcs_type = ir.FunctionType(loaded_type, [ptr_type])

            array, indices = args

            ptr = _get_element_pointer(
                context, builder, index_type, indices, array_type, array
            )

            bitwidth = array_type.dtype.bitwidth
            inst = f"ld.global.{operator}.b{bitwidth}"
            constraints = f"={CONSTRAINT_MAP[bitwidth]},l"
            ldcs = ir.InlineAsm(ldcs_type, f"{inst} $0, [$1];", constraints)
            return builder.call(ldcs, [ptr])

        return signature, codegen

    return impl


ldca_intrinsic = ld_cache_operator("ca")
ldcg_intrinsic = ld_cache_operator("cg")
ldcs_intrinsic = ld_cache_operator("cs")
ldlu_intrinsic = ld_cache_operator("lu")
ldcv_intrinsic = ld_cache_operator("cv")
```
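One way to sanity-check that the hint survives into the generated code (not necessarily how this PR's tests do it) is to compile to PTX and search for the instruction. This assumes ``numba.cuda.compile_ptx`` and the documented ``cuda.ldcg`` are available:

```python
from numba import cuda, float32

def copy_cg(dst, src):
    i = cuda.grid(1)
    if i < src.size:
        dst[i] = cuda.ldcg(src, i)

# Compile to PTX without launching anything; the cache hint should show up
# in the emitted text, e.g. "ld.global.cg.b32 ...".
ptx, _ = cuda.compile_ptx(copy_cg, (float32[::1], float32[::1]))
assert "ld.global.cg" in ptx
```

The store-side factory mirrors the load side, adding a cast of the value to the array dtype and a memory clobber in the constraints: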
```python
def st_cache_operator(operator):
    @intrinsic
    def impl(typingctx, array, index, value):
        _validate_arguments(f"st{operator}", array, index)
        _validate_bitwidth(f"st{operator}", array)

        signature = types.void(array, index, value)

        def codegen(context, builder, sig, args):
            array_type, index_type, value_type = sig.args
            stored_type = context.get_value_type(array_type.dtype)
            ptr_type = stored_type.as_pointer()
            stcs_type = ir.FunctionType(ir.VoidType(), [ptr_type, stored_type])

            array, indices, value = args

            ptr = _get_element_pointer(
                context, builder, index_type, indices, array_type, array
            )

            casted_value = context.cast(
                builder, value, value_type, array_type.dtype
            )

            bitwidth = array_type.dtype.bitwidth
            inst = f"st.global.{operator}.b{bitwidth}"
            constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
            stcs = ir.InlineAsm(stcs_type, f"{inst} [$0], $1;", constraints)
            builder.call(stcs, [ptr, casted_value])

        return signature, codegen

    return impl


stcg_intrinsic = st_cache_operator("cg")
stcs_intrinsic = st_cache_operator("cs")
stwb_intrinsic = st_cache_operator("wb")
stwt_intrinsic = st_cache_operator("wt")
```
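For a concrete sense of what the store codegen emits, here is a standalone sketch; the format strings and CONSTRAINT_MAP are copied from ``st_cache_operator`` above, evaluated for one example case:

```python
CONSTRAINT_MAP = {1: "b", 8: "r", 16: "h", 32: "r", 64: "l", 128: "q"}

operator, bitwidth = "wt", 64  # e.g. stwt on a float64 array
inst = f"st.global.{operator}.b{bitwidth}"
constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
print(f"{inst} [$0], $1;")  # st.global.wt.b64 [$0], $1;
print(constraints)          # l,l,~{memory}  (address, value, memory clobber)
```

Finally, the public stub functions are wired up to these intrinsics with CUDA-target overloads: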
```python
@overload(ldca, target="cuda")
def ol_ldca(array, i):
    def impl(array, i):
        return ldca_intrinsic(array, i)

    return impl


@overload(ldcg, target="cuda")
def ol_ldcg(array, i):
    def impl(array, i):
        return ldcg_intrinsic(array, i)

    return impl


@overload(ldcs, target="cuda")
def ol_ldcs(array, i):
    def impl(array, i):
        return ldcs_intrinsic(array, i)

    return impl


@overload(ldlu, target="cuda")
def ol_ldlu(array, i):
    def impl(array, i):
        return ldlu_intrinsic(array, i)

    return impl


@overload(ldcv, target="cuda")
def ol_ldcv(array, i):
    def impl(array, i):
        return ldcv_intrinsic(array, i)

    return impl


@overload(stcg, target="cuda")
def ol_stcg(array, i, value):
    def impl(array, i, value):
        return stcg_intrinsic(array, i, value)

    return impl


@overload(stcs, target="cuda")
def ol_stcs(array, i, value):
    def impl(array, i, value):
        return stcs_intrinsic(array, i, value)

    return impl


@overload(stwb, target="cuda")
def ol_stwb(array, i, value):
    def impl(array, i, value):
        return stwb_intrinsic(array, i, value)

    return impl


@overload(stwt, target="cuda")
def ol_stwt(array, i, value):
    def impl(array, i, value):
        return stwt_intrinsic(array, i, value)

    return impl
```