diff --git a/numba_cuda/numba/cuda/api_util.py b/numba_cuda/numba/cuda/api_util.py
index b8bffb7c1..360679eba 100644
--- a/numba_cuda/numba/cuda/api_util.py
+++ b/numba_cuda/numba/cuda/api_util.py
@@ -1,3 +1,5 @@
+from numba import types
+from numba.core import cgutils
 import numpy as np
 
 
@@ -28,3 +30,26 @@ def _fill_stride_by_order(shape, dtype, order):
     else:
         raise ValueError('must be either C/F order')
     return tuple(strides)
+
+
+def normalize_indices(context, builder, indty, inds, aryty, valty):
+    """
+    Convert an integer index or tuple of integer indices into a tuple of intp
+    """
+    if indty in types.integer_domain:
+        indty = types.UniTuple(dtype=indty, count=1)
+        indices = [inds]
+    else:
+        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
+    indices = [context.cast(builder, i, t, types.intp)
+               for t, i in zip(indty, indices)]
+
+    dtype = aryty.dtype
+    if dtype != valty:
+        raise TypeError("expect %s but got %s" % (dtype, valty))
+
+    if aryty.ndim != len(indty):
+        raise TypeError("indexing %d-D array with %d-D index" %
+                        (aryty.ndim, len(indty)))
+
+    return indty, indices
diff --git a/numba_cuda/numba/cuda/cache_hints.py b/numba_cuda/numba/cuda/cache_hints.py
new file mode 100644
index 000000000..60dc4644f
--- /dev/null
+++ b/numba_cuda/numba/cuda/cache_hints.py
@@ -0,0 +1,241 @@
+from llvmlite import ir
+from numba import types
+from numba.core import cgutils
+from numba.core.extending import intrinsic, overload
+from numba.core.errors import NumbaTypeError
+from numba.cuda.api_util import normalize_indices
+
+# Docs references:
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
+# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints
+
+
+def ldca(array, i):
+    """Generate a `ld.global.ca` instruction for element `i` of an array."""
+
+
+def ldcg(array, i):
+    """Generate a `ld.global.cg` instruction for element `i` of an array."""
+
+
+def ldcs(array, i):
+    """Generate a `ld.global.cs` instruction for element `i` of an array."""
+
+
+def ldlu(array, i):
+    """Generate a `ld.global.lu` instruction for element `i` of an array."""
+
+
+def ldcv(array, i):
+    """Generate a `ld.global.cv` instruction for element `i` of an array."""
+
+
+def stcg(array, i, value):
+    """Generate a `st.global.cg` instruction for element `i` of an array."""
+
+
+def stcs(array, i, value):
+    """Generate a `st.global.cs` instruction for element `i` of an array."""
+
+
+def stwb(array, i, value):
+    """Generate a `st.global.wb` instruction for element `i` of an array."""
+
+
+def stwt(array, i, value):
+    """Generate a `st.global.wt` instruction for element `i` of an array."""
+
+
+# See
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
+# for background on the choice of "r" for 8-bit operands - there is no
+# inline assembly constraint for 8-bit operands, but the register operand
+# of a load or store is permitted to be wider than the data it accesses.
+CONSTRAINT_MAP = {
+    1: "b",
+    8: "r",
+    16: "h",
+    32: "r",
+    64: "l",
+    128: "q"
+}
+
+
+def _validate_arguments(instruction, array, index):
+    if not isinstance(array, types.Array):
+        msg = f"{instruction} operates on arrays. Got type {array}"
+        raise NumbaTypeError(msg)
+
+    valid_index = False
+
+    if isinstance(index, types.Integer):
+        if array.ndim != 1:
+            msg = f"Expected {array.ndim} indices, got a scalar"
+            raise NumbaTypeError(msg)
+        valid_index = True
+
+    if isinstance(index, types.UniTuple):
+        if index.count != array.ndim:
+            msg = f"Expected {array.ndim} indices, got {index.count}"
+            raise NumbaTypeError(msg)
+
+        if isinstance(index.dtype, types.Integer):
+            valid_index = True
+
+    if not valid_index:
+        raise NumbaTypeError(f"{index} is not a valid index")
+
+
+def ld_cache_operator(operator):
+    @intrinsic
+    def impl(typingctx, array, index):
+        _validate_arguments(f"ld{operator}", array, index)
+
+        # TODO: validate that the bitwidth has an entry in CONSTRAINT_MAP
+
+        signature = array.dtype(array, index)
+
+        def codegen(context, builder, sig, args):
+            array_type, index_type = sig.args
+            loaded_type = context.get_value_type(array_type.dtype)
+            ptr_type = loaded_type.as_pointer()
+            ld_type = ir.FunctionType(loaded_type, [ptr_type])
+
+            array, indices = args
+
+            index_type, indices = normalize_indices(context, builder,
+                                                    index_type, indices,
+                                                    array_type,
+                                                    array_type.dtype)
+            array_struct = context.make_array(array_type)(context, builder,
+                                                          value=array)
+            ptr = cgutils.get_item_pointer(context, builder, array_type,
+                                           array_struct, indices,
+                                           wraparound=True)
+
+            bitwidth = array_type.dtype.bitwidth
+            inst = f"ld.global.{operator}.b{bitwidth}"
+            constraints = f"={CONSTRAINT_MAP[bitwidth]},l"
+            ld = ir.InlineAsm(ld_type, f"{inst} $0, [$1];", constraints)
+            return builder.call(ld, [ptr])
+
+        return signature, codegen
+
+    return impl
+
+
+ldca_intrinsic = ld_cache_operator("ca")
+ldcg_intrinsic = ld_cache_operator("cg")
+ldcs_intrinsic = ld_cache_operator("cs")
+ldlu_intrinsic = ld_cache_operator("lu")
+ldcv_intrinsic = ld_cache_operator("cv")
+
+
+def st_cache_operator(operator):
+    @intrinsic
+    def impl(typingctx, array, index, value):
+        _validate_arguments(f"st{operator}", array, index)
+
+        # TODO: validate that the bitwidth has an entry in CONSTRAINT_MAP
+
+        signature = types.void(array, index, value)
+
+        def codegen(context, builder, sig, args):
+            array_type, index_type, value_type = sig.args
+            stored_type = context.get_value_type(array_type.dtype)
+            ptr_type = stored_type.as_pointer()
+            st_type = ir.FunctionType(ir.VoidType(), [ptr_type, stored_type])
+
+            array, indices, value = args
+
+            index_type, indices = normalize_indices(context, builder,
+                                                    index_type, indices,
+                                                    array_type,
+                                                    array_type.dtype)
+            array_struct = context.make_array(array_type)(context, builder,
+                                                          value=array)
+            ptr = cgutils.get_item_pointer(context, builder, array_type,
+                                           array_struct, indices,
+                                           wraparound=True)
+
+            casted_value = context.cast(builder, value, value_type,
+                                        array_type.dtype)
+
+            bitwidth = array_type.dtype.bitwidth
+            inst = f"st.global.{operator}.b{bitwidth}"
+            constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
+            st = ir.InlineAsm(st_type, f"{inst} [$0], $1;", constraints)
+            builder.call(st, [ptr, casted_value])
+
+        return signature, codegen
+
+    return impl
+
+
+stcg_intrinsic = st_cache_operator("cg")
+stcs_intrinsic = st_cache_operator("cs")
+stwb_intrinsic = st_cache_operator("wb")
+stwt_intrinsic = st_cache_operator("wt")
+
+
+@overload(ldca, target='cuda')
+def ol_ldca(array, i):
+    def impl(array, i):
+        return ldca_intrinsic(array, i)
+    return impl
+
+
+@overload(ldcg, target='cuda')
+def ol_ldcg(array, i):
+    def impl(array, i):
+        return ldcg_intrinsic(array, i)
+    return impl
+
+
+@overload(ldcs, target='cuda')
+def ol_ldcs(array, i):
+    def impl(array, i):
+        return ldcs_intrinsic(array, i)
+    return impl
+
+
+@overload(ldlu, target='cuda')
+def ol_ldlu(array, i):
+    def impl(array, i):
+        return ldlu_intrinsic(array, i)
+    return impl
+
+
+@overload(ldcv, target='cuda')
+def ol_ldcv(array, i):
+    def impl(array, i):
+        return ldcv_intrinsic(array, i)
+    return impl
+
+
+@overload(stcg, target='cuda')
+def ol_stcg(array, i, value):
+    def impl(array, i, value):
+        return stcg_intrinsic(array, i, value)
+    return impl
+
+
+@overload(stcs, target='cuda')
+def ol_stcs(array, i, value):
+    def impl(array, i, value):
+        return stcs_intrinsic(array, i, value)
+    return impl
+
+
+@overload(stwb, target='cuda')
+def ol_stwb(array, i, value):
+    def impl(array, i, value):
+        return stwb_intrinsic(array, i, value)
+    return impl
+
+
+@overload(stwt, target='cuda')
+def ol_stwt(array, i, value):
+    def impl(array, i, value):
+        return stwt_intrinsic(array, i, value)
+    return impl
diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py
index 5de024d1c..34a57d403 100644
--- a/numba_cuda/numba/cuda/cudaimpl.py
+++ b/numba_cuda/numba/cuda/cudaimpl.py
@@ -13,6 +13,7 @@ from numba.np.npyimpl import register_ufuncs
 from .cudadrv import nvvm
 from numba import cuda
+from numba.cuda.api_util import normalize_indices
 from numba.cuda import nvvmutils, stubs, errors
 from numba.cuda.types import dim3, CUDADispatcher
 
 
@@ -692,29 +693,6 @@ def impl(context, builder, sig, args):
 lower(math.degrees, types.f8)(gen_deg_rad(_rad2deg))
 
 
-def _normalize_indices(context, builder, indty, inds, aryty, valty):
-    """
-    Convert integer indices into tuple of intp
-    """
-    if indty in types.integer_domain:
-        indty = types.UniTuple(dtype=indty, count=1)
-        indices = [inds]
-    else:
-        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
-    indices = [context.cast(builder, i, t, types.intp)
-               for t, i in zip(indty, indices)]
-
-    dtype = aryty.dtype
-    if dtype != valty:
-        raise TypeError("expect %s but got %s" % (dtype, valty))
-
-    if aryty.ndim != len(indty):
-        raise TypeError("indexing %d-D array with %d-D index" %
-                        (aryty.ndim, len(indty)))
-
-    return indty, indices
-
-
 def _atomic_dispatcher(dispatch_fn):
     def imp(context, builder, sig, args):
         # The common argument handling code
@@ -722,8 +700,8 @@ def imp(context, builder, sig, args):
         ary, inds, val = args
         dtype = aryty.dtype
 
-        indty, indices = _normalize_indices(context, builder, indty, inds,
-                                            aryty, valty)
+        indty, indices = normalize_indices(context, builder, indty, inds,
+                                           aryty, valty)
 
         lary = context.make_array(aryty)(context, builder, ary)
         ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
@@ -917,8 +895,8 @@ def ptx_atomic_cas(context, builder, sig, args):
     aryty, indty, oldty, valty = sig.args
     ary, inds, old, val = args
 
-    indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
-                                        valty)
+    indty, indices = normalize_indices(context, builder, indty, inds, aryty,
+                                       valty)
 
     lary = context.make_array(aryty)(context, builder, ary)
     ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
diff --git a/numba_cuda/numba/cuda/device_init.py b/numba_cuda/numba/cuda/device_init.py
index 9df5ae99d..85accbe7b 100644
--- a/numba_cuda/numba/cuda/device_init.py
+++ b/numba_cuda/numba/cuda/device_init.py
@@ -1,6 +1,8 @@
 # Re export
 import sys
 from numba.cuda import cg
+from numba.cuda.cache_hints import (ldca, ldcg, ldcs, ldlu, ldcv, stcg, stcs,
+                                    stwb, stwt)
 from .stubs import (threadIdx, blockIdx, blockDim, gridDim, laneid, warpsize,
                     syncwarp, shared, local, const, atomic,
                     shfl_sync_intrinsic, vote_sync_intrinsic, match_any_sync,
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py b/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py
new file mode 100644
index 000000000..8ed91355f
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py
@@ -0,0 +1,112 @@
+from numba import cuda, errors, typeof, types
+from numba.cuda.testing import unittest, CUDATestCase
+import numpy as np
+
+tested_types = (
+    np.int8, np.int16, np.int32, np.int64,
+    np.uint8, np.uint16, np.uint32, np.uint64,
+    np.float16, np.float32, np.float64,
+)
+
+complex_types = (
+    np.complex64, np.complex128
+)
+
+load_operators = (
+    (cuda.ldca, 'ca'),
+    (cuda.ldcg, 'cg'),
+    (cuda.ldcs, 'cs'),
+    (cuda.ldlu, 'lu'),
+    (cuda.ldcv, 'cv')
+)
+store_operators = (
+    (cuda.stcg, 'cg'),
+    (cuda.stcs, 'cs'),
+    (cuda.stwb, 'wb'),
+    (cuda.stwt, 'wt')
+)
+
+
+class TestCacheHints(CUDATestCase):
+    def test_loads(self):
+        for operator, modifier in load_operators:
+            @cuda.jit
+            def f(r, x):
+                for i in range(len(r)):
+                    r[i] = operator(x, i)
+
+            for ty in tested_types:
+                with self.subTest(operator=operator, ty=ty):
+                    x = np.arange(5).astype(ty)
+                    r = np.zeros_like(x)
+
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    # Check PTX contains a cache-policy load instruction
+                    numba_type = typeof(x)
+                    bitwidth = numba_type.dtype.bitwidth
+                    sig = (numba_type, numba_type)
+                    ptx, _ = cuda.compile_ptx(f, sig)
+
+                    self.assertIn(f"ld.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_stores(self):
+        for operator, modifier in store_operators:
+            @cuda.jit
+            def f(r, x):
+                for i in range(len(r)):
+                    operator(r, i, x[i])
+
+            for ty in tested_types:
+                with self.subTest(operator=operator, ty=ty):
+                    x = np.arange(5).astype(ty)
+                    r = np.zeros_like(x)
+
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    # Check PTX contains a cache-policy store instruction
+                    numba_type = typeof(x)
+                    bitwidth = numba_type.dtype.bitwidth
+                    sig = (numba_type, numba_type)
+                    ptx, _ = cuda.compile_ptx(f, sig)
+
+                    self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_bad_indices(self):
+        def float_indices(x):
+            cuda.ldcs(x, 1.0)
+
+        sig_1d = (types.float32[::1],)
+
+        msg = "float64 is not a valid index"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(float_indices, sig_1d)
+
+        def too_long_indices(x):
+            cuda.ldcs(x, (1, 2))
+
+        msg = "Expected 1 indices, got 2"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_long_indices, sig_1d)
+
+        def too_short_indices_scalar(x):
+            cuda.ldcs(x, 1)
+
+        def too_short_indices_tuple(x):
+            cuda.ldcs(x, (1,))
+
+        sig_2d = (types.float32[:,::1],)
+
+        msg = "Expected 2 indices, got a scalar"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_short_indices_scalar, sig_2d)
+
+        msg = "Expected 2 indices, got 1"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_short_indices_tuple, sig_2d)
+
+
+if __name__ == '__main__':
+    unittest.main()
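
For reference, a minimal usage sketch of the helpers added above (the kernel,
data size and launch configuration are illustrative and not part of the patch):

    from numba import cuda
    import numpy as np


    @cuda.jit
    def streaming_copy(dst, src):
        # ldcs/stcs request the "streaming" cache policy, hinting that the
        # data is likely accessed only once and need not stay in cache.
        i = cuda.grid(1)
        if i < src.size:
            cuda.stcs(dst, i, cuda.ldcs(src, i))


    src = np.arange(1 << 20, dtype=np.float32)
    dst = np.zeros_like(src)
    streaming_copy[(src.size + 255) // 256, 256](dst, src)
    np.testing.assert_array_equal(dst, src)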