Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions numba_cuda/numba/cuda/api_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from numba import types
from numba.core import cgutils
import numpy as np


Expand Down Expand Up @@ -28,3 +30,26 @@ def _fill_stride_by_order(shape, dtype, order):
else:
raise ValueError('must be either C/F order')
return tuple(strides)


def normalize_indices(context, builder, indty, inds, aryty, valty):
    """Normalize an index expression for array element access.

    A scalar integer index is promoted to a one-element UniTuple; a tuple
    index is unpacked and each element cast to intp.

    Returns the (possibly promoted) index type and the list of index
    values.  Raises TypeError when the value type does not match the
    array dtype, or when the index arity does not match the array rank.
    """
    if indty in types.integer_domain:
        # A lone integer index becomes a 1-tuple so the rest of the
        # function can treat every index uniformly.
        indty = types.UniTuple(dtype=indty, count=1)
        indices = [inds]
    else:
        unpacked = cgutils.unpack_tuple(builder, inds, count=len(indty))
        indices = [
            context.cast(builder, val, ty, types.intp)
            for ty, val in zip(indty, unpacked)
        ]

    if aryty.dtype != valty:
        raise TypeError("expect %s but got %s" % (aryty.dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError("indexing %d-D array with %d-D index"
                        % (aryty.ndim, len(indty)))

    return indty, indices
241 changes: 241 additions & 0 deletions numba_cuda/numba/cuda/cache_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from llvmlite import ir
from numba import types
from numba.core import cgutils
from numba.core.extending import intrinsic, overload
from numba.core.errors import NumbaTypeError
from numba.cuda.api_util import normalize_indices

# Docs references:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints


# The following functions are stubs: they are never executed as Python.
# Their CUDA implementations are supplied by the @overload registrations
# further down this module, which forward to the generated intrinsics.
def ldca(array, i):
    """Generate a `ld.global.ca` instruction for element `i` of an array."""


def ldcg(array, i):
    """Generate a `ld.global.cg` instruction for element `i` of an array."""


def ldcs(array, i):
    """Generate a `ld.global.cs` instruction for element `i` of an array."""


def ldlu(array, i):
    """Generate a `ld.global.lu` instruction for element `i` of an array."""


def ldcv(array, i):
    """Generate a `ld.global.cv` instruction for element `i` of an array."""


def stcg(array, i, value):
    """Generate a `st.global.cg` instruction for element `i` of an array."""


def stcs(array, i, value):
    """Generate a `st.global.cs` instruction for element `i` of an array."""


def stwb(array, i, value):
    """Generate a `st.global.wb` instruction for element `i` of an array."""


def stwt(array, i, value):
    """Generate a `st.global.wt` instruction for element `i` of an array."""


# Maps an operand bitwidth to the inline-assembly register constraint
# letter used for an operand of that width.  See
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
# for background on the choice of "r" for 8-bit operands - there is
# no constraint for 8-bit operands, but the operand for loads and
# stores is permitted to be greater than 8 bits.
CONSTRAINT_MAP = {
    1: "b",
    8: "r",
    16: "h",
    32: "r",
    64: "l",
    128: "q"
}


def _validate_arguments(instruction, array, index):
    """Typing-time validation for the cache-hint load/store functions.

    Checks that ``array`` is a Numba array type and that ``index`` is
    either a scalar integer (for a 1-D array) or a uniform tuple of
    integers whose length matches the array rank.

    Raises NumbaTypeError on any violation; returns None on success.
    """
    # This block was reconstructed from a scraped diff that had review-UI
    # chrome interleaved with the code; the only code change applied is
    # the reviewer-accepted generator expression inside all().
    if not isinstance(array, types.Array):
        msg = f"{instruction} operates on arrays. Got type {array}"
        raise NumbaTypeError(msg)

    valid_index = False

    if isinstance(index, types.Integer):
        # A scalar index is only valid for a 1-D array.
        if array.ndim != 1:
            msg = f"Expected {array.ndim} indices, got a scalar"
            raise NumbaTypeError(msg)
        valid_index = True

    if isinstance(index, types.UniTuple):
        if index.count != array.ndim:
            msg = f"Expected {array.ndim} indices, got {index.count}"
            raise NumbaTypeError(msg)

        # NOTE(review): index.dtype is the UniTuple *element* type, not a
        # sequence of per-element types - confirm that iterating it is
        # intended; `isinstance(index.dtype, types.Integer)` may be the
        # equivalent (and clearer) check for a uniform tuple.
        if all(isinstance(t, types.Integer) for t in index.dtype):
            valid_index = True

    if not valid_index:
        raise NumbaTypeError(f"{index} is not a valid index")


def ld_cache_operator(operator):
    """Build an intrinsic that emits a ``ld.global.<operator>`` load.

    ``operator`` is a PTX cache operator ("ca", "cg", "cs", "lu", "cv").
    The returned intrinsic takes (array, index) and returns the loaded
    element.
    """
    @intrinsic
    def impl(typingctx, array, index):
        _validate_arguments(f"ld{operator}", array, index)

        # Need to validate bitwidth

        signature = array.dtype(array, index)

        def codegen(context, builder, sig, args):
            aryty, idxty = sig.args
            elem_type = context.get_value_type(aryty.dtype)
            asm_fnty = ir.FunctionType(elem_type, [elem_type.as_pointer()])

            ary, idx = args

            idxty, indices = normalize_indices(context, builder, idxty,
                                               idx, aryty, aryty.dtype)
            ary_struct = context.make_array(aryty)(context, builder,
                                                   value=ary)
            ptr = cgutils.get_item_pointer(context, builder, aryty,
                                           ary_struct, indices,
                                           wraparound=True)

            # Untyped b<N> load; the constraint letter comes from the
            # operand width.
            nbits = aryty.dtype.bitwidth
            inst = f"ld.global.{operator}.b{nbits}"
            constraints = f"={CONSTRAINT_MAP[nbits]},l"
            asm = ir.InlineAsm(asm_fnty, f"{inst} $0, [$1];", constraints)
            return builder.call(asm, [ptr])

        return signature, codegen

    return impl


# Concrete load intrinsics, one per PTX cache operator.
ldca_intrinsic = ld_cache_operator("ca")
ldcg_intrinsic = ld_cache_operator("cg")
ldcs_intrinsic = ld_cache_operator("cs")
ldlu_intrinsic = ld_cache_operator("lu")
ldcv_intrinsic = ld_cache_operator("cv")


def st_cache_operator(operator):
    """Build an intrinsic that emits a ``st.global.<operator>`` store.

    ``operator`` is a PTX cache operator ("cg", "cs", "wb", "wt").  The
    returned intrinsic takes (array, index, value) and returns nothing;
    the value is cast to the array dtype before being stored.
    """
    @intrinsic
    def impl(typingctx, array, index, value):
        _validate_arguments(f"st{operator}", array, index)

        # Need to validate bitwidth

        signature = types.void(array, index, value)

        def codegen(context, builder, sig, args):
            aryty, idxty, valty = sig.args
            elem_type = context.get_value_type(aryty.dtype)
            asm_fnty = ir.FunctionType(ir.VoidType(),
                                       [elem_type.as_pointer(), elem_type])

            ary, idx, val = args

            idxty, indices = normalize_indices(context, builder, idxty,
                                               idx, aryty, aryty.dtype)
            ary_struct = context.make_array(aryty)(context, builder,
                                                   value=ary)
            ptr = cgutils.get_item_pointer(context, builder, aryty,
                                           ary_struct, indices,
                                           wraparound=True)

            stored = context.cast(builder, val, valty, aryty.dtype)

            # Untyped b<N> store; the memory clobber tells LLVM the asm
            # writes memory.
            nbits = aryty.dtype.bitwidth
            inst = f"st.global.{operator}.b{nbits}"
            constraints = f"l,{CONSTRAINT_MAP[nbits]},~{{memory}}"
            asm = ir.InlineAsm(asm_fnty, f"{inst} [$0], $1;", constraints)
            builder.call(asm, [ptr, stored])

        return signature, codegen

    return impl


# Concrete store intrinsics, one per PTX cache operator.
stcg_intrinsic = st_cache_operator("cg")
stcs_intrinsic = st_cache_operator("cs")
stwb_intrinsic = st_cache_operator("wb")
stwt_intrinsic = st_cache_operator("wt")


@overload(ldca, target='cuda')
def ol_ldca(array, i):
    # CUDA-target overload: the returned function is jit-compiled and
    # forwards to the corresponding intrinsic.
    def ldca_impl(array, i):
        return ldca_intrinsic(array, i)
    return ldca_impl
Comment on lines +183 to +185
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this just be:

Suggested change
def impl(array, i):
return ldca_intrinsic(array, i)
return impl
return lcda_intrinsic

for each of these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An overload should return a Python function that gets jit-compiled to become the implementation of the function it's overloading. ldca_intrinsic is an intrinsic, called from within the compiled function.

I suspect it will not work to return an intrinsic as an overload implementation, but even if it did, it would feel jarring to me to contract a level of abstraction in the implementation here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does seem really weird to my eye that a function foo(*args) whose only line is to call bar(*args) cannot simply be replaced with the call to bar(*args).

This feels like a numbaism perhaps that violates basic substitution rules, but this is of course not a blocking comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the confusion here comes from reading the code as if it's going to be interpreted by the Python interpreter, rather than seeing it as a form of metaprogramming, which is what functions decorated with @overload implement.

An attempt to summarise the pertinent points:

  • An @overload function returns a Python function that Numba-CUDA compiles.
  • An @intrinsic function (like ldca_intrinsic, generated by ld_cache_operator() above), is a function that returns a tuple of (signature, codegen), where:
    • signature is the typing signature that Numba-CUDA uses during type inference to determine and validate the function's argument and return types, and
    • codegen is a function that Numba-CUDA calls to generate the LLVM IR for the implementation.
  • During the compilation process (when impl() is being compiled), the typing and lowering for intrinsics are resolved, and the implementation of the intrinsic generated by the codegen() function is inserted into the generated code.
  • In the compilation process for impl(), type inference and code generation implements ldca_intrinsic() as a function that returns a scalar value, and accepts an array and an index as arguments. The typing is defined by the signature on line 96, and the code generation function follows below it.
  • Therefore, if you replace return impl with return ldca_intrinsic in ol_ldca(), you have replaced a function that accepts an array and an index then returns a scalar (impl(array, i) -> array.dtype), with one that accepts a typing context and a number of LLVM IR values, and returns a signature and code generation function (impl(typingctx, array, index) -> (Signature, Function)).

I hope this clarifies things a bit, but for a more complete understanding I can't see a shortcut that avoids working through the low- and high-level extension API documentation for Numba:

That said, I do have a couple of worked examples that show the whole flow in one notebook for each of these APIs for the CUDA target, which may also help:

They may be a little out-of-date and need a couple of bits updating, but the general flow of them is still relevant.

We're using the High-level API here (that @overload and @intrinsic are part of), but it's probably hard to understand the High-level API without first understanding the Low-level API. The High-level API is intended to make it quicker and easier to write Numba extensions, but in my view the main thing it provides is some shorthand for a lot of Low-level API work.

Finally - what happens if you implement the suggested change (modulo the typo in the name of the function above)? You will get:

AssertionError: Implementation function returned by `@overload` has an unexpected type.  Got <intrinsic impl>

cc @kaeun97 as the explanation might be helpful in understanding the PR as a whole.



@overload(ldcg, target='cuda')
def ol_ldcg(array, i):
    # CUDA-target overload forwarding to the ldcg intrinsic.
    def ldcg_impl(array, i):
        return ldcg_intrinsic(array, i)
    return ldcg_impl


@overload(ldcs, target='cuda')
def ol_ldcs(array, i):
    # CUDA-target overload forwarding to the ldcs intrinsic.
    def ldcs_impl(array, i):
        return ldcs_intrinsic(array, i)
    return ldcs_impl


@overload(ldlu, target='cuda')
def ol_ldlu(array, i):
    # CUDA-target overload forwarding to the ldlu intrinsic.
    def ldlu_impl(array, i):
        return ldlu_intrinsic(array, i)
    return ldlu_impl


@overload(ldcv, target='cuda')
def ol_ldcv(array, i):
    # CUDA-target overload forwarding to the ldcv intrinsic.
    def ldcv_impl(array, i):
        return ldcv_intrinsic(array, i)
    return ldcv_impl


@overload(stcg, target='cuda')
def ol_stcg(array, i, value):
    # CUDA-target overload forwarding to the stcg intrinsic.
    def stcg_impl(array, i, value):
        return stcg_intrinsic(array, i, value)
    return stcg_impl


@overload(stcs, target='cuda')
def ol_stcs(array, i, value):
    # CUDA-target overload forwarding to the stcs intrinsic.
    def stcs_impl(array, i, value):
        return stcs_intrinsic(array, i, value)
    return stcs_impl


@overload(stwb, target='cuda')
def ol_stwb(array, i, value):
    # CUDA-target overload forwarding to the stwb intrinsic.
    def stwb_impl(array, i, value):
        return stwb_intrinsic(array, i, value)
    return stwb_impl


@overload(stwt, target='cuda')
def ol_stwt(array, i, value):
    # CUDA-target overload forwarding to the stwt intrinsic.
    def stwt_impl(array, i, value):
        return stwt_intrinsic(array, i, value)
    return stwt_impl
32 changes: 5 additions & 27 deletions numba_cuda/numba/cuda/cudaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba.np.npyimpl import register_ufuncs
from .cudadrv import nvvm
from numba import cuda
from numba.cuda.api_util import normalize_indices
from numba.cuda import nvvmutils, stubs, errors
from numba.cuda.types import dim3, CUDADispatcher

Expand Down Expand Up @@ -692,38 +693,15 @@ def impl(context, builder, sig, args):
lower(math.degrees, types.f8)(gen_deg_rad(_rad2deg))


def _normalize_indices(context, builder, indty, inds, aryty, valty):
    """
    Convert integer indices into tuple of intp

    A scalar integer index type is promoted to a one-element UniTuple;
    tuple indices are unpacked and each element cast to intp.  Raises
    TypeError on a value-type/dtype mismatch or when the index arity does
    not match the array rank.  Returns (index type, list of index values).
    """
    if indty in types.integer_domain:
        # Promote a lone integer index to a 1-tuple.
        indty = types.UniTuple(dtype=indty, count=1)
        indices = [inds]
    else:
        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
        indices = [context.cast(builder, i, t, types.intp)
                   for t, i in zip(indty, indices)]

    dtype = aryty.dtype
    if dtype != valty:
        raise TypeError("expect %s but got %s" % (dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError("indexing %d-D array with %d-D index" %
                        (aryty.ndim, len(indty)))

    return indty, indices


def _atomic_dispatcher(dispatch_fn):
def imp(context, builder, sig, args):
# The common argument handling code
aryty, indty, valty = sig.args
ary, inds, val = args
dtype = aryty.dtype

indty, indices = _normalize_indices(context, builder, indty, inds,
aryty, valty)
indty, indices = normalize_indices(context, builder, indty, inds,
aryty, valty)

lary = context.make_array(aryty)(context, builder, ary)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
Expand Down Expand Up @@ -917,8 +895,8 @@ def ptx_atomic_cas(context, builder, sig, args):
aryty, indty, oldty, valty = sig.args
ary, inds, old, val = args

indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
valty)
indty, indices = normalize_indices(context, builder, indty, inds, aryty,
valty)

lary = context.make_array(aryty)(context, builder, ary)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
Expand Down
2 changes: 2 additions & 0 deletions numba_cuda/numba/cuda/device_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Re export
import sys
from numba.cuda import cg
from numba.cuda.cache_hints import (ldca, ldcg, ldcs, ldlu, ldcv, stcg, stcs,
stwb, stwt)
from .stubs import (threadIdx, blockIdx, blockDim, gridDim, laneid, warpsize,
syncwarp, shared, local, const, atomic,
shfl_sync_intrinsic, vote_sync_intrinsic, match_any_sync,
Expand Down
Loading