-
Notifications
You must be signed in to change notification settings - Fork 54
Add support for cache-hinted load and store operations #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,241 @@ | ||||||||||
| from llvmlite import ir | ||||||||||
| from numba import types | ||||||||||
| from numba.core import cgutils | ||||||||||
| from numba.core.extending import intrinsic, overload | ||||||||||
| from numba.core.errors import NumbaTypeError | ||||||||||
| from numba.cuda.api_util import normalize_indices | ||||||||||
|
|
||||||||||
| # Docs references: | ||||||||||
| # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld | ||||||||||
| # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints | ||||||||||
|
|
||||||||||
|
|
||||||||||
def ldca(array, i):
    """Load element ``i`` of ``array`` with a ``ld.global.ca`` instruction.

    Per the PTX ISA, ``.ca`` requests caching at all levels. This Python
    stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def ldcg(array, i):
    """Load element ``i`` of ``array`` with a ``ld.global.cg`` instruction.

    Per the PTX ISA, ``.cg`` requests caching at the global level. This
    Python stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def ldcs(array, i):
    """Load element ``i`` of ``array`` with a ``ld.global.cs`` instruction.

    Per the PTX ISA, ``.cs`` is the cache-streaming hint for data likely
    to be accessed once. This Python stub has no body of its own; the
    CUDA-target implementation is registered separately through an
    overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def ldlu(array, i):
    """Load element ``i`` of ``array`` with a ``ld.global.lu`` instruction.

    Per the PTX ISA, ``.lu`` is the last-use hint. This Python stub has no
    body of its own; the CUDA-target implementation is registered
    separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def ldcv(array, i):
    """Load element ``i`` of ``array`` with a ``ld.global.cv`` instruction.

    Per the PTX ISA, ``.cv`` means "don't cache and fetch again". This
    Python stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def stcg(array, i, value):
    """Store ``value`` to element ``i`` of ``array`` via ``st.global.cg``.

    Per the PTX ISA, ``.cg`` requests caching at the global level. This
    Python stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def stcs(array, i, value):
    """Store ``value`` to element ``i`` of ``array`` via ``st.global.cs``.

    Per the PTX ISA, ``.cs`` is the cache-streaming hint for data likely
    to be accessed once. This Python stub has no body of its own; the
    CUDA-target implementation is registered separately through an
    overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def stwb(array, i, value):
    """Store ``value`` to element ``i`` of ``array`` via ``st.global.wb``.

    Per the PTX ISA, ``.wb`` is the write-back cache operator. This Python
    stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
def stwt(array, i, value):
    """Store ``value`` to element ``i`` of ``array`` via ``st.global.wt``.

    Per the PTX ISA, ``.wt`` is the write-through cache operator. This
    Python stub has no body of its own; the CUDA-target implementation is
    registered separately through an overload in this module.
    """
|
|
||||||||||
|
|
||||||||||
# Inline-assembly operand constraint, keyed by the operand's width in bits.
#
# There is no dedicated constraint for 8-bit operands, so 8-bit accesses use
# "r": the operand of a load or store is permitted to be wider than 8 bits.
# Background:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
CONSTRAINT_MAP = {
    1: "b",
    8: "r",
    16: "h",
    32: "r",
    64: "l",
    128: "q",
}
|
|
||||||||||
|
|
||||||||||
def _validate_arguments(instruction, array, index):
    """Check at typing time that ``index`` can address ``array``.

    Parameters
    ----------
    instruction : str
        Name of the operation being typed (e.g. ``"ldcs"``); used only in
        the error message for a non-array argument.
    array : numba type
        Type of the object being indexed; must be a ``types.Array``.
    index : numba type
        Type of the index; either a ``types.Integer`` (1-D arrays only) or
        a ``types.UniTuple`` of integers with one entry per dimension.

    Raises
    ------
    NumbaTypeError
        If ``array`` is not an array type, if the number of indices does
        not match ``array.ndim``, or if ``index`` is not an integer or a
        homogeneous tuple of integers.
    """
    if not isinstance(array, types.Array):
        msg = f"{instruction} operates on arrays. Got type {array}"
        raise NumbaTypeError(msg)

    if isinstance(index, types.Integer):
        # A bare scalar can only address a one-dimensional array.
        if array.ndim != 1:
            msg = f"Expected {array.ndim} indices, got a scalar"
            raise NumbaTypeError(msg)
        return

    if isinstance(index, types.UniTuple):
        if index.count != array.ndim:
            msg = f"Expected {array.ndim} indices, got {index.count}"
            raise NumbaTypeError(msg)

        # ``UniTuple.dtype`` is the single element type shared by every
        # member of the tuple, so it is checked directly; it is not an
        # iterable of per-element types.
        if isinstance(index.dtype, types.Integer):
            return

    raise NumbaTypeError(f"{index} is not a valid index")
|
|
||||||||||
|
|
||||||||||
def ld_cache_operator(operator):
    """Build an intrinsic that loads one array element with a cache hint.

    Parameters
    ----------
    operator : str
        Cache-operator suffix spliced into the instruction name, producing
        ``ld.global.<operator>.b<bitwidth>``.

    Returns
    -------
    The ``@intrinsic``-decorated function taking ``(array, index)`` and
    returning the loaded element, typed as ``array.dtype``.

    Raises (at typing time)
    -----------------------
    NumbaTypeError
        If the arguments fail ``_validate_arguments`` or the element type
        has no bit width listed in ``CONSTRAINT_MAP``.
    """
    @intrinsic
    def impl(typingctx, array, index):
        _validate_arguments(f"ld{operator}", array, index)

        # Reject element types with no matching inline-asm constraint here,
        # so the user sees a typing error instead of a KeyError or
        # AttributeError raised from inside code generation.
        if getattr(array.dtype, "bitwidth", None) not in CONSTRAINT_MAP:
            msg = (f"ld{operator} does not support arrays with element "
                   f"type {array.dtype}")
            raise NumbaTypeError(msg)

        signature = array.dtype(array, index)

        def codegen(context, builder, sig, args):
            array_type, index_type = sig.args
            array, indices = args

            # The inline asm is modelled as a function from element pointer
            # to element value.
            element_type = context.get_value_type(array_type.dtype)
            asm_fn_type = ir.FunctionType(element_type,
                                          [element_type.as_pointer()])

            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            bitwidth = array_type.dtype.bitwidth
            inst = f"ld.global.{operator}.b{bitwidth}"
            # Output operand first ("=..."), then the 64-bit address.
            constraints = f"={CONSTRAINT_MAP[bitwidth]},l"
            load_asm = ir.InlineAsm(asm_fn_type, f"{inst} $0, [$1];",
                                    constraints)
            return builder.call(load_asm, [ptr])

        return signature, codegen

    return impl
|
|
||||||||||
|
|
||||||||||
# One load intrinsic per cache operator, created in the order listed.
(
    ldca_intrinsic,
    ldcg_intrinsic,
    ldcs_intrinsic,
    ldlu_intrinsic,
    ldcv_intrinsic,
) = map(ld_cache_operator, ("ca", "cg", "cs", "lu", "cv"))
|
|
||||||||||
|
|
||||||||||
def st_cache_operator(operator):
    """Build an intrinsic that stores one array element with a cache hint.

    Parameters
    ----------
    operator : str
        Cache-operator suffix spliced into the instruction name, producing
        ``st.global.<operator>.b<bitwidth>``.

    Returns
    -------
    The ``@intrinsic``-decorated function taking ``(array, index, value)``
    and returning nothing (typed as ``types.void``).

    Raises (at typing time)
    -----------------------
    NumbaTypeError
        If the arguments fail ``_validate_arguments`` or the element type
        has no bit width listed in ``CONSTRAINT_MAP``.
    """
    @intrinsic
    def impl(typingctx, array, index, value):
        _validate_arguments(f"st{operator}", array, index)

        # Reject element types with no matching inline-asm constraint here,
        # so the user sees a typing error instead of a KeyError or
        # AttributeError raised from inside code generation.
        if getattr(array.dtype, "bitwidth", None) not in CONSTRAINT_MAP:
            msg = (f"st{operator} does not support arrays with element "
                   f"type {array.dtype}")
            raise NumbaTypeError(msg)

        signature = types.void(array, index, value)

        def codegen(context, builder, sig, args):
            array_type, index_type, value_type = sig.args
            array, indices, value = args

            # The inline asm is modelled as a void function of
            # (element pointer, element value).
            element_type = context.get_value_type(array_type.dtype)
            asm_fn_type = ir.FunctionType(
                ir.VoidType(), [element_type.as_pointer(), element_type]
            )

            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            # Convert the incoming value to the array's element type before
            # it is handed to the store instruction.
            casted_value = context.cast(builder, value, value_type,
                                        array_type.dtype)

            bitwidth = array_type.dtype.bitwidth
            inst = f"st.global.{operator}.b{bitwidth}"
            # 64-bit address, then the value operand; the asm is declared
            # as clobbering memory.
            constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
            store_asm = ir.InlineAsm(asm_fn_type, f"{inst} [$0], $1;",
                                     constraints)
            builder.call(store_asm, [ptr, casted_value])

        return signature, codegen

    return impl
|
|
||||||||||
|
|
||||||||||
# One store intrinsic per cache operator, created in the order listed.
(
    stcg_intrinsic,
    stcs_intrinsic,
    stwb_intrinsic,
    stwt_intrinsic,
) = map(st_cache_operator, ("cg", "cs", "wb", "wt"))
|
|
||||||||||
|
|
||||||||||
@overload(ldca, target='cuda')
def ol_ldca(array, i):
    """CUDA-target overload of ``ldca`` that defers to ``ldca_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``ldca``; its parameters mirror the stub's.
    def ldca_impl(array, i):
        return ldca_intrinsic(array, i)

    return ldca_impl
|
Comment on lines
+183
to
+185
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't this just be:
Suggested change
for each of these?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. An overload should return a Python function that gets jit-compiled to become the implementation of the function it's overloading. I suspect it will not work to return an intrinsic as an overload implementation, but even if it did, it would feel jarring to me to contract a level of abstraction in the implementation here.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does seem really weird to my eye that a function This feels like a numbaism perhaps that violates basic substitution rules, but this is of course not a blocking comment.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the confusion here comes from reading the code as if it's going to be interpreted by the Python interpreter, rather than seeing it as a form of metaprogramming, which is what functions decorated with `@overload` are. An attempt to summarise the pertinent points:
I hope this clarifies things a bit, but for a more complete understanding I can't see a shortcut that avoids working through the low- and high-level extension API documentation for Numba: That said, I do have a couple of worked examples that show the whole flow in one notebook for each of these APIs for the CUDA target, which may also help: They may be a little out-of-date and need a couple of bits updating, but the general flow of them is still relevant. We're using the High-level API here (that Finally - what happens if you implement the suggested change (modulo the typo in the name of the function above)? You will get: cc @kaeun97 as the explanation might be helpful in understanding the PR as a whole. |
||||||||||
|
|
||||||||||
|
|
||||||||||
@overload(ldcg, target='cuda')
def ol_ldcg(array, i):
    """CUDA-target overload of ``ldcg`` that defers to ``ldcg_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``ldcg``; its parameters mirror the stub's.
    def ldcg_impl(array, i):
        return ldcg_intrinsic(array, i)

    return ldcg_impl
|
|
||||||||||
|
|
||||||||||
@overload(ldcs, target='cuda')
def ol_ldcs(array, i):
    """CUDA-target overload of ``ldcs`` that defers to ``ldcs_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``ldcs``; its parameters mirror the stub's.
    def ldcs_impl(array, i):
        return ldcs_intrinsic(array, i)

    return ldcs_impl
|
|
||||||||||
|
|
||||||||||
@overload(ldlu, target='cuda')
def ol_ldlu(array, i):
    """CUDA-target overload of ``ldlu`` that defers to ``ldlu_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``ldlu``; its parameters mirror the stub's.
    def ldlu_impl(array, i):
        return ldlu_intrinsic(array, i)

    return ldlu_impl
|
|
||||||||||
|
|
||||||||||
@overload(ldcv, target='cuda')
def ol_ldcv(array, i):
    """CUDA-target overload of ``ldcv`` that defers to ``ldcv_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``ldcv``; its parameters mirror the stub's.
    def ldcv_impl(array, i):
        return ldcv_intrinsic(array, i)

    return ldcv_impl
|
|
||||||||||
|
|
||||||||||
@overload(stcg, target='cuda')
def ol_stcg(array, i, value):
    """CUDA-target overload of ``stcg`` that defers to ``stcg_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``stcg``; its parameters mirror the stub's.
    def stcg_impl(array, i, value):
        return stcg_intrinsic(array, i, value)

    return stcg_impl
|
|
||||||||||
|
|
||||||||||
@overload(stcs, target='cuda')
def ol_stcs(array, i, value):
    """CUDA-target overload of ``stcs`` that defers to ``stcs_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``stcs``; its parameters mirror the stub's.
    def stcs_impl(array, i, value):
        return stcs_intrinsic(array, i, value)

    return stcs_impl
|
|
||||||||||
|
|
||||||||||
@overload(stwb, target='cuda')
def ol_stwb(array, i, value):
    """CUDA-target overload of ``stwb`` that defers to ``stwb_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``stwb``; its parameters mirror the stub's.
    def stwb_impl(array, i, value):
        return stwb_intrinsic(array, i, value)

    return stwb_impl
|
|
||||||||||
|
|
||||||||||
@overload(stwt, target='cuda')
def ol_stwt(array, i, value):
    """CUDA-target overload of ``stwt`` that defers to ``stwt_intrinsic``."""
    # The function returned here is the one that gets JIT-compiled as the
    # implementation of ``stwt``; its parameters mirror the stub's.
    def stwt_impl(array, i, value):
        return stwt_intrinsic(array, i, value)

    return stwt_impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No reason to create a list if you don't need to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will apply this change in the follow up PR :)