-
Notifications
You must be signed in to change notification settings - Fork 54
Fix Invalid NVVM IR emitted when lowering shfl_sync APIs #231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9be71e7
889a5c3
5b6217f
c412f32
d11f31c
8181ff7
97904a3
8532abc
c5dacbb
f2b5c65
b359f9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |||
|
|
||||
| from numba import cuda, types | ||||
| from numba.core import cgutils | ||||
| from numba.core.errors import RequireLiteralValue | ||||
| from numba.core.errors import RequireLiteralValue, TypingError | ||||
| from numba.core.typing import signature | ||||
| from numba.core.extending import overload_attribute, overload_method | ||||
| from numba.cuda import nvvmutils | ||||
|
|
@@ -205,3 +205,174 @@ def syncthreads_or(typingctx, predicate): | |||
| @overload_method(types.Integer, "bit_count", target="cuda") | ||||
| def integer_bit_count(i): | ||||
| return lambda i: cuda.popc(i) | ||||
|
|
||||
|
|
||||
| # ------------------------------------------------------------------------------- | ||||
| # Warp shuffle functions | ||||
| # | ||||
| # References: | ||||
| # | ||||
| # - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions | ||||
| # - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#data-movement | ||||
| # | ||||
| # Notes: | ||||
| # | ||||
| # - The public CUDA C/C++ and Numba Python APIs for these intrinsics use | ||||
| # different names for parameters to the NVVM IR specification. So that we | ||||
| # can correlate the implementation with the documentation, the @intrinsic | ||||
| # API functions map the public API arguments to the NVVM intrinsic | ||||
| # arguments. | ||||
| # - The NVVM IR specification requires some of the parameters (e.g. mode) to be | ||||
| # constants. It's therefore essential that we pass in some values to the | ||||
| # shfl_sync_intrinsic function (e.g. the mode and c values). | ||||
| # - Normally parameters for intrinsic functions in Numba would be given the | ||||
| # same name as used in the API, and would contain a type. However, because we | ||||
| # have to pass in some values and some types (and there is divergence between | ||||
| # the names in the intrinsic documentation and the public APIs) we instead | ||||
| # follow the convention of naming shfl_sync_intrinsic parameters with a | ||||
| # suffix of _type or _value depending on whether they contain a type or a | ||||
| # value. | ||||
|
|
||||
|
|
||||
| @intrinsic | ||||
| def shfl_sync(typingctx, mask, value, src_lane): | ||||
| """ | ||||
| Shuffles ``value`` across the masked warp and returns the value from | ||||
| ``src_lane``. If this is outside the warp, then the given value is | ||||
| returned. | ||||
| """ | ||||
| membermask_type = mask | ||||
| mode_value = 0 | ||||
| a_type = value | ||||
| b_type = src_lane | ||||
| c_value = 0x1F | ||||
| return shfl_sync_intrinsic( | ||||
| typingctx, membermask_type, mode_value, a_type, b_type, c_value | ||||
| ) | ||||
|
|
||||
|
|
||||
| @intrinsic | ||||
| def shfl_up_sync(typingctx, mask, value, delta): | ||||
| """ | ||||
| Shuffles ``value`` across the masked warp and returns the value from | ||||
| ``(laneid - delta)``. If this is outside the warp, then the given value is | ||||
| returned. | ||||
| """ | ||||
| membermask_type = mask | ||||
| mode_value = 1 | ||||
| a_type = value | ||||
| b_type = delta | ||||
| c_value = 0 | ||||
| return shfl_sync_intrinsic( | ||||
| typingctx, membermask_type, mode_value, a_type, b_type, c_value | ||||
| ) | ||||
|
|
||||
|
|
||||
| @intrinsic | ||||
| def shfl_down_sync(typingctx, mask, value, delta): | ||||
| """ | ||||
| Shuffles ``value`` across the masked warp and returns the value from | ||||
| ``(laneid + delta)``. If this is outside the warp, then the given value is | ||||
| returned. | ||||
| """ | ||||
| membermask_type = mask | ||||
| mode_value = 2 | ||||
| a_type = value | ||||
| b_type = delta | ||||
| c_value = 0x1F | ||||
| return shfl_sync_intrinsic( | ||||
| typingctx, membermask_type, mode_value, a_type, b_type, c_value | ||||
| ) | ||||
|
|
||||
|
|
||||
| @intrinsic | ||||
| def shfl_xor_sync(typingctx, mask, value, lane_mask): | ||||
| """ | ||||
| Shuffles ``value`` across the masked warp and returns the value from | ||||
| ``(laneid ^ lane_mask)``. | ||||
| """ | ||||
| membermask_type = mask | ||||
| mode_value = 3 | ||||
| a_type = value | ||||
| b_type = lane_mask | ||||
| c_value = 0x1F | ||||
| return shfl_sync_intrinsic( | ||||
| typingctx, membermask_type, mode_value, a_type, b_type, c_value | ||||
| ) | ||||
|
|
||||
|
|
||||
| def shfl_sync_intrinsic( | ||||
| typingctx, | ||||
| membermask_type, | ||||
| mode_value, | ||||
| a_type, | ||||
| b_type, | ||||
| c_value, | ||||
| ): | ||||
| if a_type not in (types.i4, types.i8, types.f4, types.f8): | ||||
| raise TypingError( | ||||
| "shfl_sync only supports 32- and 64-bit ints and floats" | ||||
| ) | ||||
|
|
||||
| def codegen(context, builder, sig, args): | ||||
| """ | ||||
| The NVVM shfl_sync intrinsic only supports i32, but the CUDA C/C++ | ||||
| intrinsic supports both 32- and 64-bit ints and floats, so for feature | ||||
| parity, i32, i64, f32, and f64 are implemented. Floats by way of | ||||
| bitcasting the float to an int, then shuffling, then bitcasting | ||||
| back.""" | ||||
| membermask, a, b = args | ||||
|
|
||||
| # Types | ||||
| a_type = sig.args[1] | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we don't need this, as per https://github.com/NVIDIA/numba-cuda/pull/231/files#r2070337358:
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Further to our discussion earlier where you suggested not doing this so that it's not captured from the outer function - I'll not commit this suggestion, and merge the PR as-is. |
||||
| return_type = context.get_value_type(sig.return_type) | ||||
| i32 = ir.IntType(32) | ||||
| i64 = ir.IntType(64) | ||||
|
|
||||
| if a_type in types.real_domain: | ||||
| a = builder.bitcast(a, ir.IntType(a_type.bitwidth)) | ||||
|
|
||||
| # NVVM intrinsic definition | ||||
| arg_types = (i32, i32, i32, i32, i32) | ||||
| shfl_return_type = ir.LiteralStructType((i32, ir.IntType(1))) | ||||
| fnty = ir.FunctionType(shfl_return_type, arg_types) | ||||
|
|
||||
| fname = "llvm.nvvm.shfl.sync.i32" | ||||
| shfl_sync = cgutils.get_or_insert_function(builder.module, fnty, fname) | ||||
|
|
||||
| # Intrinsic arguments | ||||
| mode = ir.Constant(i32, mode_value) | ||||
| c = ir.Constant(i32, c_value) | ||||
| membermask = builder.trunc(membermask, i32) | ||||
| b = builder.trunc(b, i32) | ||||
|
|
||||
| if a_type.bitwidth == 32: | ||||
| a = builder.trunc(a, i32) | ||||
| ret = builder.call(shfl_sync, (membermask, mode, a, b, c)) | ||||
| d = builder.extract_value(ret, 0) | ||||
| else: | ||||
| # Handle 64-bit values by shuffling as two 32-bit values and | ||||
| # packing the result into 64 bits. | ||||
|
|
||||
| # Extract high and low parts | ||||
| lo = builder.trunc(a, i32) | ||||
| a_lshr = builder.lshr(a, ir.Constant(i64, 32)) | ||||
| hi = builder.trunc(a_lshr, i32) | ||||
|
|
||||
| # Shuffle individual parts | ||||
| ret_lo = builder.call(shfl_sync, (membermask, mode, lo, b, c)) | ||||
| ret_hi = builder.call(shfl_sync, (membermask, mode, hi, b, c)) | ||||
|
|
||||
| # Combine individual result parts into a 64-bit result | ||||
| d_lo = builder.extract_value(ret_lo, 0) | ||||
| d_hi = builder.extract_value(ret_hi, 0) | ||||
| d_lo_64 = builder.zext(d_lo, i64) | ||||
| d_hi_64 = builder.zext(d_hi, i64) | ||||
| d_shl = builder.shl(d_hi_64, ir.Constant(i64, 32)) | ||||
| d = builder.or_(d_shl, d_lo_64) | ||||
|
|
||||
| return builder.bitcast(d, return_type) | ||||
|
|
||||
| sig = signature(a_type, membermask_type, a_type, b_type) | ||||
|
|
||||
| return sig, codegen | ||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a slight confusion when reading the code, in that
`a_type` is both passed in from typing and extracted from the lowering `arg` parameter. I later realized that the first `a_type` is used for typing and the second for lowering, and that they are executed at different times. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point - there's no need to get it again from the signature, so I think that definition later of it can be removed - see below. What do you think?