Skip to content
42 changes: 0 additions & 42 deletions numba_cuda/numba/cuda/cudadecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,45 +100,6 @@ class Cuda_syncwarp(ConcreteTemplate):
cases = [signature(types.none), signature(types.none, types.i4)]


@register
class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
    """Typing template for ``cuda.shfl_sync_intrinsic``.

    Every case has the shape (mask, mode, value, lane/delta, clamp) ->
    (value, pred); only the shuffled value's type varies across cases.
    """

    key = cuda.shfl_sync_intrinsic
    # One signature per supported value type: 32/64-bit ints and floats.
    cases = [
        signature(
            types.Tuple((value_type, types.b1)),
            types.i4,
            types.i4,
            value_type,
            types.i4,
            types.i4,
        )
        for value_type in (types.i4, types.i8, types.f4, types.f8)
    ]


@register
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
key = cuda.vote_sync_intrinsic
Expand Down Expand Up @@ -815,9 +776,6 @@ def resolve_threadfence_system(self, mod):
def resolve_syncwarp(self, mod):
    """Resolve the ``syncwarp`` module attribute to its typing template."""
    return types.Function(Cuda_syncwarp)

def resolve_shfl_sync_intrinsic(self, mod):
    """Resolve ``shfl_sync_intrinsic`` to its typing template."""
    return types.Function(Cuda_shfl_sync_intrinsic)

def resolve_vote_sync_intrinsic(self, mod):
    """Resolve ``vote_sync_intrinsic`` to its typing template."""
    return types.Function(Cuda_vote_sync_intrinsic)

Expand Down
63 changes: 0 additions & 63 deletions numba_cuda/numba/cuda/cudaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,69 +204,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
return context.get_dummy_value()


@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
)
def ptx_shfl_sync_i32(context, builder, sig, args):
    """
    Lower ``shfl_sync_intrinsic`` to the NVVM shuffle intrinsic.

    The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
    function supports both 32 and 64 bit ints and floats, so for feature
    parity, i64, f32, and f64 are implemented. Floats by way of bitcasting
    the float to an int, then shuffling, then bitcasting back. And 64-bit
    values by packing them into 2 32bit values, shuffling those, and then
    packing back together.
    """
    mask, mode, value, index, clamp = args
    value_type = sig.args[2]
    # Floats are shuffled through an integer of the same bit width.
    if value_type in types.real_domain:
        value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
    fname = "llvm.nvvm.shfl.sync.i32"
    lmod = builder.module
    # The NVVM intrinsic returns {i32 value, i1 pred} and takes five i32
    # operands: (membermask, mode, a, b, c).
    fnty = ir.FunctionType(
        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
        (
            ir.IntType(32),
            ir.IntType(32),
            ir.IntType(32),
            ir.IntType(32),
            ir.IntType(32),
        ),
    )
    func = cgutils.get_or_insert_function(lmod, fnty, fname)
    if value_type.bitwidth == 32:
        ret = builder.call(func, (mask, mode, value, index, clamp))
        if value_type == types.float32:
            # Rebuild the (f32, pred) struct from the shuffled integer bits.
            rv = builder.extract_value(ret, 0)
            pred = builder.extract_value(ret, 1)
            fv = builder.bitcast(rv, ir.FloatType())
            ret = cgutils.make_anonymous_struct(builder, (fv, pred))
    else:
        # 64-bit case: split into low/high 32-bit halves and shuffle each.
        value1 = builder.trunc(value, ir.IntType(32))
        value_lshr = builder.lshr(value, context.get_constant(types.i8, 32))
        value2 = builder.trunc(value_lshr, ir.IntType(32))
        ret1 = builder.call(func, (mask, mode, value1, index, clamp))
        ret2 = builder.call(func, (mask, mode, value2, index, clamp))
        rv1 = builder.extract_value(ret1, 0)
        rv2 = builder.extract_value(ret2, 0)
        # Both half-shuffles use the same mask and lane, so one predicate
        # suffices.
        pred = builder.extract_value(ret1, 1)
        rv1_64 = builder.zext(rv1, ir.IntType(64))
        rv2_64 = builder.zext(rv2, ir.IntType(64))
        rv_shl = builder.shl(rv2_64, context.get_constant(types.i8, 32))
        rv = builder.or_(rv_shl, rv1_64)
        if value_type == types.float64:
            rv = builder.bitcast(rv, ir.DoubleType())
        ret = cgutils.make_anonymous_struct(builder, (rv, pred))
    return ret


@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
def ptx_vote_sync(context, builder, sig, args):
fname = "llvm.nvvm.vote.sync"
Expand Down
9 changes: 4 additions & 5 deletions numba_cuda/numba/cuda/device_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
local,
const,
atomic,
shfl_sync_intrinsic,
vote_sync_intrinsic,
match_any_sync,
match_all_sync,
Expand All @@ -40,6 +39,10 @@
syncthreads_and,
syncthreads_count,
syncthreads_or,
shfl_sync,
shfl_up_sync,
shfl_down_sync,
shfl_xor_sync,
)
from .cudadrv.error import CudaSupportError
from numba.cuda.cudadrv.driver import (
Expand Down Expand Up @@ -68,10 +71,6 @@
any_sync,
eq_sync,
ballot_sync,
shfl_sync,
shfl_up_sync,
shfl_down_sync,
shfl_xor_sync,
)

from .kernels import reduction
Expand Down
39 changes: 0 additions & 39 deletions numba_cuda/numba/cuda/intrinsic_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,3 @@ def ballot_sync(mask, predicate):
and are within the given mask.
"""
return numba.cuda.vote_sync_intrinsic(mask, 3, predicate)[0]


@jit(device=True)
def shfl_sync(mask, value, src_lane):
    """
    Shuffle ``value`` across the warp lanes selected by ``mask`` and return
    the value held by ``src_lane``. If that lane is outside the warp, the
    caller's own value is returned.
    """
    # Mode 0 selects the indexed shuffle; clamp 0x1F keeps the lane in-warp.
    shuffled = numba.cuda.shfl_sync_intrinsic(mask, 0, value, src_lane, 0x1F)
    return shuffled[0]


@jit(device=True)
def shfl_up_sync(mask, value, delta):
    """
    Shuffle ``value`` across the warp lanes selected by ``mask`` and return
    the value held by lane ``(laneid - delta)``. If that lane is outside
    the warp, the caller's own value is returned.
    """
    # Mode 1 selects the "up" shuffle; clamp 0 bounds the segment start.
    shuffled = numba.cuda.shfl_sync_intrinsic(mask, 1, value, delta, 0)
    return shuffled[0]


@jit(device=True)
def shfl_down_sync(mask, value, delta):
    """
    Shuffle ``value`` across the warp lanes selected by ``mask`` and return
    the value held by lane ``(laneid + delta)``. If that lane is outside
    the warp, the caller's own value is returned.
    """
    # Mode 2 selects the "down" shuffle; clamp 0x1F bounds the segment end.
    shuffled = numba.cuda.shfl_sync_intrinsic(mask, 2, value, delta, 0x1F)
    return shuffled[0]


@jit(device=True)
def shfl_xor_sync(mask, value, lane_mask):
    """
    Shuffle ``value`` across the warp lanes selected by ``mask`` and return
    the value held by lane ``(laneid ^ lane_mask)``.
    """
    # Mode 3 selects the "bfly" (xor) shuffle; clamp 0x1F spans the warp.
    shuffled = numba.cuda.shfl_sync_intrinsic(mask, 3, value, lane_mask, 0x1F)
    return shuffled[0]
173 changes: 172 additions & 1 deletion numba_cuda/numba/cuda/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from numba import cuda, types
from numba.core import cgutils
from numba.core.errors import RequireLiteralValue
from numba.core.errors import RequireLiteralValue, TypingError
from numba.core.typing import signature
from numba.core.extending import overload_attribute, overload_method
from numba.cuda import nvvmutils
Expand Down Expand Up @@ -205,3 +205,174 @@ def syncthreads_or(typingctx, predicate):
@overload_method(types.Integer, "bit_count", target="cuda")
def integer_bit_count(i):
    """CUDA overload of ``int.bit_count`` via the population-count op."""

    def impl(i):
        return cuda.popc(i)

    return impl


# -------------------------------------------------------------------------------
# Warp shuffle functions
#
# References:
#
# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions
# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#data-movement
#
# Notes:
#
# - The public CUDA C/C++ and Numba Python APIs for these intrinsics use
# different names for parameters to the NVVM IR specification. So that we
# can correlate the implementation with the documentation, the @intrinsic
# API functions map the public API arguments to the NVVM intrinsic
# arguments.
# - The NVVM IR specification requires some of the parameters (e.g. mode) to be
# constants. It's therefore essential that we pass in some values to the
# shfl_sync_intrinsic function (e.g. the mode and c values).
# - Normally parameters for intrinsic functions in Numba would be given the
# same name as used in the API, and would contain a type. However, because we
have to pass in some values and some types (and there is divergence between
# the names in the intrinsic documentation and the public APIs) we instead
# follow the convention of naming shfl_sync_intrinsic parameters with a
# suffix of _type or _value depending on whether they contain a type or a
# value.


@intrinsic
def shfl_sync(typingctx, mask, value, src_lane):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``src_lane``. If this is outside the warp, then the given value is
    returned.
    """
    # Map the public-API arguments onto the NVVM operands: mode 0 selects
    # the indexed shuffle and c = 0x1F clamps the source lane to the warp.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=0,
        a_type=value,
        b_type=src_lane,
        c_value=0x1F,
    )


@intrinsic
def shfl_up_sync(typingctx, mask, value, delta):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid - delta)``. If this is outside the warp, then the given value
    is returned.
    """
    # Map the public-API arguments onto the NVVM operands: mode 1 selects
    # the "up" shuffle and c = 0 clamps at the start of the segment.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=1,
        a_type=value,
        b_type=delta,
        c_value=0,
    )


@intrinsic
def shfl_down_sync(typingctx, mask, value, delta):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid + delta)``. If this is outside the warp, then the given value
    is returned.
    """
    # Map the public-API arguments onto the NVVM operands: mode 2 selects
    # the "down" shuffle and c = 0x1F clamps at the end of the segment.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=2,
        a_type=value,
        b_type=delta,
        c_value=0x1F,
    )


@intrinsic
def shfl_xor_sync(typingctx, mask, value, lane_mask):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid ^ lane_mask)``.
    """
    # Map the public-API arguments onto the NVVM operands: mode 3 selects
    # the butterfly (xor) shuffle and c = 0x1F spans the whole warp.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=3,
        a_type=value,
        b_type=lane_mask,
        c_value=0x1F,
    )


def shfl_sync_intrinsic(
    typingctx,
    membermask_type,
    mode_value,
    a_type,
    b_type,
    c_value,
):
    """
    Shared typing and lowering implementation behind the four public shuffle
    intrinsics (``shfl_sync``, ``shfl_up_sync``, ``shfl_down_sync``,
    ``shfl_xor_sync``).

    Parameters suffixed ``_type`` carry Numba types (used at typing time);
    parameters suffixed ``_value`` carry constants baked into the generated
    code, because the NVVM IR specification requires the mode and clamp
    operands of the shuffle intrinsic to be constants.

    Returns a ``(signature, codegen)`` pair, as required by ``@intrinsic``.
    Raises ``TypingError`` if the shuffled value is not a 32- or 64-bit int
    or float.

    NOTE: the scraped review-comment text that was interleaved in this
    definition has been removed; the code below is the merged version.
    """
    if a_type not in (types.i4, types.i8, types.f4, types.f8):
        raise TypingError(
            "shfl_sync only supports 32- and 64-bit ints and floats"
        )

    def codegen(context, builder, sig, args):
        """
        The NVVM shfl_sync intrinsic only supports i32, but the CUDA C/C++
        intrinsic supports both 32- and 64-bit ints and floats, so for feature
        parity, i32, i64, f32, and f64 are implemented. Floats by way of
        bitcasting the float to an int, then shuffling, then bitcasting
        back."""
        membermask, a, b = args

        # Types. Re-bind a_type from the signature at lowering time rather
        # than closing over the typing-time argument (kept as merged).
        a_type = sig.args[1]
        return_type = context.get_value_type(sig.return_type)
        i32 = ir.IntType(32)
        i64 = ir.IntType(64)

        # Floats are shuffled through an integer of the same bit width.
        if a_type in types.real_domain:
            a = builder.bitcast(a, ir.IntType(a_type.bitwidth))

        # NVVM intrinsic definition: {i32, i1} @llvm.nvvm.shfl.sync.i32(
        #     i32 membermask, i32 mode, i32 a, i32 b, i32 c)
        arg_types = (i32, i32, i32, i32, i32)
        shfl_return_type = ir.LiteralStructType((i32, ir.IntType(1)))
        fnty = ir.FunctionType(shfl_return_type, arg_types)

        fname = "llvm.nvvm.shfl.sync.i32"
        shfl_sync = cgutils.get_or_insert_function(builder.module, fnty, fname)

        # Intrinsic arguments; mode and c must be constants per the NVVM spec.
        mode = ir.Constant(i32, mode_value)
        c = ir.Constant(i32, c_value)
        membermask = builder.trunc(membermask, i32)
        b = builder.trunc(b, i32)

        if a_type.bitwidth == 32:
            a = builder.trunc(a, i32)
            ret = builder.call(shfl_sync, (membermask, mode, a, b, c))
            d = builder.extract_value(ret, 0)
        else:
            # Handle 64-bit values by shuffling as two 32-bit values and
            # packing the result into 64 bits.

            # Extract high and low parts
            lo = builder.trunc(a, i32)
            a_lshr = builder.lshr(a, ir.Constant(i64, 32))
            hi = builder.trunc(a_lshr, i32)

            # Shuffle individual parts
            ret_lo = builder.call(shfl_sync, (membermask, mode, lo, b, c))
            ret_hi = builder.call(shfl_sync, (membermask, mode, hi, b, c))

            # Combine individual result parts into a 64-bit result
            d_lo = builder.extract_value(ret_lo, 0)
            d_hi = builder.extract_value(ret_hi, 0)
            d_lo_64 = builder.zext(d_lo, i64)
            d_hi_64 = builder.zext(d_hi, i64)
            d_shl = builder.shl(d_hi_64, ir.Constant(i64, 32))
            d = builder.or_(d_shl, d_lo_64)

        # Bitcast back to the declared return type (restores floats).
        return builder.bitcast(d, return_type)

    sig = signature(a_type, membermask_type, a_type, b_type)

    return sig, codegen
Loading