diff --git a/docs/source/reference/cache_hints.rst b/docs/source/reference/cache_hints.rst
new file mode 100644
index 000000000..233e4378b
--- /dev/null
+++ b/docs/source/reference/cache_hints.rst
@@ -0,0 +1,76 @@
+..
+   SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+   SPDX-License-Identifier: BSD-2-Clause
+
+.. _cache-hints:
+
+Cache Hints for Memory Operations
+=================================
+
+These functions provide explicit control over caching behavior for memory
+operations. They generate PTX instructions with cache policy hints that can
+optimize specific memory access patterns. All functions accept arrays or
+pointers with elements of any signed/unsigned integer or floating-point
+type of a supported bitwidth.
+
+.. seealso:: `Cache Operators
+   <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_
+   in the PTX ISA documentation.
+
+.. function:: numba.cuda.ldca(array, i)
+
+   Load element ``i`` from ``array`` with the cache-all policy (``ld.global.ca``).
+   This is the default caching behavior.
+
+.. function:: numba.cuda.ldcg(array, i)
+
+   Load element ``i`` from ``array`` with the cache-global policy (``ld.global.cg``).
+   Useful for data shared across thread blocks.
+
+.. function:: numba.cuda.ldcs(array, i)
+
+   Load element ``i`` from ``array`` with the cache-streaming policy
+   (``ld.global.cs``). Optimized for streaming data that is accessed only once.
+
+.. function:: numba.cuda.ldlu(array, i)
+
+   Load element ``i`` from ``array`` with the last-use policy (``ld.global.lu``).
+   Indicates that the data is unlikely to be used again.
+
+.. function:: numba.cuda.ldcv(array, i)
+
+   Load element ``i`` from ``array`` with the cache-volatile policy (``ld.global.cv``).
+   Use for volatile data that may be modified externally.
+
+.. function:: numba.cuda.stcg(array, i, value)
+
+   Store ``value`` to ``array[i]`` with the cache-global policy (``st.global.cg``).
+   Useful for data shared across thread blocks.
+
+.. function:: numba.cuda.stcs(array, i, value)
+
+   Store ``value`` to ``array[i]`` with the cache-streaming policy (``st.global.cs``).
+   Optimized for streaming writes.
+
+.. function:: numba.cuda.stwb(array, i, value)
+
+   Store ``value`` to ``array[i]`` with the write-back policy (``st.global.wb``).
+   This is the default caching behavior.
+
+.. function:: numba.cuda.stwt(array, i, value)
+
+   Store ``value`` to ``array[i]`` with the write-through policy (``st.global.wt``).
+   Writes through the cache hierarchy to memory.
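+
+For example, a kernel that reads and writes each element exactly once might
+combine ``ldcs`` and ``stcs``. A minimal sketch:
+
+.. code-block:: python
+
+   from numba import cuda
+
+   @cuda.jit
+   def copy_streaming(dst, src):
+       i = cuda.grid(1)
+       if i < len(src):
+           cuda.stcs(dst, i, cuda.ldcs(src, i))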
diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst
index 040eed00c..89ca12a4a 100644
--- a/docs/source/reference/index.rst
+++ b/docs/source/reference/index.rst
@@ -9,6 +9,7 @@ Reference documentation
 
    host.rst
    kernel.rst
+   cache_hints.rst
    types.rst
    memory.rst
    libdevice.rst
diff --git a/numba_cuda/numba/cuda/api_util.py b/numba_cuda/numba/cuda/api_util.py
index 580b4eb20..d1bcfbb8b 100644
--- a/numba_cuda/numba/cuda/api_util.py
+++ b/numba_cuda/numba/cuda/api_util.py
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: BSD-2-Clause
+from numba import types
+from numba.core import cgutils
 import numpy as np
 import functools
 
 
@@ -47,3 +49,28 @@ def _fill_stride_by_order(shape, dtype, order):
     else:
         raise ValueError("must be either C/F order")
     return tuple(strides)
+
+
+def normalize_indices(context, builder, indty, inds, aryty, valty):
+    """
+    Convert integer indices into tuple of intp
+    """
+    if indty in types.integer_domain:
+        indty = types.UniTuple(dtype=indty, count=1)
+        indices = [inds]
+    else:
+        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
+    indices = [
+        context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
+    ]
+
+    dtype = aryty.dtype
+    if dtype != valty:
+        raise TypeError("expect %s but got %s" % (dtype, valty))
+
+    if aryty.ndim != len(indty):
+        raise TypeError(
+            "indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
+        )
+
+    return indty, indices
diff --git a/numba_cuda/numba/cuda/cache_hints.py b/numba_cuda/numba/cuda/cache_hints.py
new file mode 100644
index 000000000..557e4cdbe
--- /dev/null
+++ b/numba_cuda/numba/cuda/cache_hints.py
@@ -0,0 +1,294 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+from llvmlite import ir
+from numba import types
+from numba.cuda import cgutils
+from numba.cuda.extending import intrinsic, overload
+from numba.cuda.core.errors import NumbaTypeError
+from numba.cuda.api_util import normalize_indices
+
+# Docs references:
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
+# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints
+
+
+def ldca(array, i):
+    """Generate an `ld.global.ca` instruction for element `i` of an array."""
+
+
+def ldcg(array, i):
+    """Generate an `ld.global.cg` instruction for element `i` of an array."""
+
+
+def ldcs(array, i):
+    """Generate an `ld.global.cs` instruction for element `i` of an array."""
+
+
+def ldlu(array, i):
+    """Generate an `ld.global.lu` instruction for element `i` of an array."""
+
+
+def ldcv(array, i):
+    """Generate an `ld.global.cv` instruction for element `i` of an array."""
+
+
+def stcg(array, i, value):
+    """Generate an `st.global.cg` instruction for element `i` of an array."""
+
+
+def stcs(array, i, value):
+    """Generate an `st.global.cs` instruction for element `i` of an array."""
+
+
+def stwb(array, i, value):
+    """Generate an `st.global.wb` instruction for element `i` of an array."""
+
+
+def stwt(array, i, value):
+    """Generate an `st.global.wt` instruction for element `i` of an array."""
+
+
+# See
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
+# for background on the choice of "r" for 8-bit operands - there is no
+# constraint letter for 8-bit operands, but the register supplied to a
+# load or store is permitted to be wider than 8 bits.
+CONSTRAINT_MAP = {1: "b", 8: "r", 16: "h", 32: "r", 64: "l", 128: "q"}
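+# For example, a 32-bit load with the "cs" operator is emitted as
+# "ld.global.cs.b32 $0, [$1];" with the constraint string "=r,l".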
Got type {array}" + raise NumbaTypeError(msg) + + valid_index = False + + if isinstance(index, types.Integer): + if is_array and array.ndim != 1: + # for pointers, any integer index is valid + msg = f"Expected {array.ndim} indices, got a scalar" + raise NumbaTypeError(msg) + valid_index = True + + if isinstance(index, types.UniTuple): + if is_pointer: + msg = f"Pointers only support scalar indexing, got tuple of {index.count}" + raise NumbaTypeError(msg) + + if index.count != array.ndim: + msg = f"Expected {array.ndim} indices, got {index.count}" + raise NumbaTypeError(msg) + + if isinstance(index.dtype, types.Integer): + valid_index = True + + if not valid_index: + raise NumbaTypeError(f"{index} is not a valid index") + + +def _validate_bitwidth(instruction, array): + dtype = array.dtype + + if not isinstance(dtype, (types.Integer, types.Float)): + msg = ( + f"{instruction} requires array of integer or float type, " + f"got {dtype}" + ) + raise NumbaTypeError(msg) + + bitwidth = dtype.bitwidth + if bitwidth not in CONSTRAINT_MAP: + valid_widths = sorted(CONSTRAINT_MAP.keys()) + msg = ( + f"{instruction} requires array dtype with bitwidth " + f"in {valid_widths}, got bitwidth {bitwidth}" + ) + raise NumbaTypeError(msg) + + +def _get_element_pointer( + context, builder, index_type, indices, array_type, array +): + if isinstance(array_type, types.CPointer): + return builder.gep(array, [indices]) + else: + index_type, indices = normalize_indices( + context, + builder, + index_type, + indices, + array_type, + array_type.dtype, + ) + array_struct = context.make_array(array_type)( + context, builder, value=array + ) + return cgutils.get_item_pointer( + context, + builder, + array_type, + array_struct, + indices, + wraparound=True, + ) + + +def ld_cache_operator(operator): + @intrinsic + def impl(typingctx, array, index): + _validate_arguments(f"ld{operator}", array, index) + _validate_bitwidth(f"ld{operator}", array) + + signature = array.dtype(array, index) + + def codegen(context, builder, sig, args): + array_type, index_type = sig.args + loaded_type = context.get_value_type(array_type.dtype) + ptr_type = loaded_type.as_pointer() + ldcs_type = ir.FunctionType(loaded_type, [ptr_type]) + + array, indices = args + + ptr = _get_element_pointer( + context, builder, index_type, indices, array_type, array + ) + + bitwidth = array_type.dtype.bitwidth + inst = f"ld.global.{operator}.b{bitwidth}" + constraints = f"={CONSTRAINT_MAP[bitwidth]},l" + ldcs = ir.InlineAsm(ldcs_type, f"{inst} $0, [$1];", constraints) + return builder.call(ldcs, [ptr]) + + return signature, codegen + + return impl + + +ldca_intrinsic = ld_cache_operator("ca") +ldcg_intrinsic = ld_cache_operator("cg") +ldcs_intrinsic = ld_cache_operator("cs") +ldlu_intrinsic = ld_cache_operator("lu") +ldcv_intrinsic = ld_cache_operator("cv") + + +def st_cache_operator(operator): + @intrinsic + def impl(typingctx, array, index, value): + _validate_arguments(f"st{operator}", array, index) + _validate_bitwidth(f"st{operator}", array) + + signature = types.void(array, index, value) + + def codegen(context, builder, sig, args): + array_type, index_type, value_type = sig.args + stored_type = context.get_value_type(array_type.dtype) + ptr_type = stored_type.as_pointer() + stcs_type = ir.FunctionType(ir.VoidType(), [ptr_type, stored_type]) + + array, indices, value = args + + ptr = _get_element_pointer( + context, builder, index_type, indices, array_type, array + ) + + casted_value = context.cast( + builder, value, value_type, array_type.dtype + ) + + 
+            bitwidth = array_type.dtype.bitwidth
+            inst = f"st.global.{operator}.b{bitwidth}"
+            constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
+            st_asm = ir.InlineAsm(st_type, f"{inst} [$0], $1;", constraints)
+            builder.call(st_asm, [ptr, casted_value])
+
+        return signature, codegen
+
+    return impl
+
+
+stcg_intrinsic = st_cache_operator("cg")
+stcs_intrinsic = st_cache_operator("cs")
+stwb_intrinsic = st_cache_operator("wb")
+stwt_intrinsic = st_cache_operator("wt")
+
+
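+# CUDA-target overloads: each public function forwards to its intrinsic so
+# the cache-hint operations can be called directly inside kernels.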
+@overload(ldca, target="cuda")
+def ol_ldca(array, i):
+    def impl(array, i):
+        return ldca_intrinsic(array, i)
+
+    return impl
+
+
+@overload(ldcg, target="cuda")
+def ol_ldcg(array, i):
+    def impl(array, i):
+        return ldcg_intrinsic(array, i)
+
+    return impl
+
+
+@overload(ldcs, target="cuda")
+def ol_ldcs(array, i):
+    def impl(array, i):
+        return ldcs_intrinsic(array, i)
+
+    return impl
+
+
+@overload(ldlu, target="cuda")
+def ol_ldlu(array, i):
+    def impl(array, i):
+        return ldlu_intrinsic(array, i)
+
+    return impl
+
+
+@overload(ldcv, target="cuda")
+def ol_ldcv(array, i):
+    def impl(array, i):
+        return ldcv_intrinsic(array, i)
+
+    return impl
+
+
+@overload(stcg, target="cuda")
+def ol_stcg(array, i, value):
+    def impl(array, i, value):
+        return stcg_intrinsic(array, i, value)
+
+    return impl
+
+
+@overload(stcs, target="cuda")
+def ol_stcs(array, i, value):
+    def impl(array, i, value):
+        return stcs_intrinsic(array, i, value)
+
+    return impl
+
+
+@overload(stwb, target="cuda")
+def ol_stwb(array, i, value):
+    def impl(array, i, value):
+        return stwb_intrinsic(array, i, value)
+
+    return impl
+
+
+@overload(stwt, target="cuda")
+def ol_stwt(array, i, value):
+    def impl(array, i, value):
+        return stwt_intrinsic(array, i, value)
+
+    return impl
diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py
index e0133fe4c..a48c7951b 100644
--- a/numba_cuda/numba/cuda/cudaimpl.py
+++ b/numba_cuda/numba/cuda/cudaimpl.py
@@ -18,6 +18,7 @@ from numba.cuda.np.npyimpl import register_ufuncs
 
 from .cudadrv import nvvm
 from numba import cuda
+from numba.cuda.api_util import normalize_indices
 from numba.cuda import nvvmutils, stubs
 from numba.cuda.types.ext_types import dim3, CUDADispatcher
 
@@ -576,31 +577,6 @@ def impl(context, builder, sig, args):
 lower(math.degrees, types.f8)(gen_deg_rad(_rad2deg))
 
 
-def _normalize_indices(context, builder, indty, inds, aryty, valty):
-    """
-    Convert integer indices into tuple of intp
-    """
-    if indty in types.integer_domain:
-        indty = types.UniTuple(dtype=indty, count=1)
-        indices = [inds]
-    else:
-        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
-    indices = [
-        context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
-    ]
-
-    dtype = aryty.dtype
-    if dtype != valty:
-        raise TypeError("expect %s but got %s" % (dtype, valty))
-
-    if aryty.ndim != len(indty):
-        raise TypeError(
-            "indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
-        )
-
-    return indty, indices
-
-
 def _atomic_dispatcher(dispatch_fn):
     def imp(context, builder, sig, args):
         # The common argument handling code
@@ -608,7 +584,7 @@ def imp(context, builder, sig, args):
         ary, inds, val = args
         dtype = aryty.dtype
 
-        indty, indices = _normalize_indices(
+        indty, indices = normalize_indices(
             context, builder, indty, inds, aryty, valty
         )
 
@@ -818,7 +794,7 @@ def ptx_atomic_cas(context, builder, sig, args):
     aryty, indty, oldty, valty = sig.args
     ary, inds, old, val = args
 
-    indty, indices = _normalize_indices(
+    indty, indices = normalize_indices(
         context, builder, indty, inds, aryty, valty
     )
 
diff --git a/numba_cuda/numba/cuda/device_init.py b/numba_cuda/numba/cuda/device_init.py
index ed3f5249c..004a7b538 100644
--- a/numba_cuda/numba/cuda/device_init.py
+++ b/numba_cuda/numba/cuda/device_init.py
@@ -4,6 +4,17 @@
 # Re export
 import sys
 from numba.cuda import cg
+from numba.cuda.cache_hints import (
+    ldca,
+    ldcg,
+    ldcs,
+    ldlu,
+    ldcv,
+    stcg,
+    stcs,
+    stwb,
+    stwt,
+)
 from .stubs import (
     threadIdx,
     blockIdx,
diff --git a/numba_cuda/numba/cuda/simulator/api.py b/numba_cuda/numba/cuda/simulator/api.py
index ce50a158c..7f98aebf7 100644
--- a/numba_cuda/numba/cuda/simulator/api.py
+++ b/numba_cuda/numba/cuda/simulator/api.py
@@ -164,3 +164,16 @@ def defer_cleanup():
 
 def is_supported_version():
     return True
+
+
+# Cache hint operations are not supported on the simulator
+ldca = None
+ldcg = None
+ldcs = None
+ldlu = None
+ldcv = None
+
+stcg = None
+stcs = None
+stwb = None
+stwt = None
diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py b/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py
new file mode 100644
index 000000000..17461e2b5
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py
@@ -0,0 +1,212 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+from numba import cuda, errors, typeof, types
+from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim
+import numpy as np
+
+tested_types = (
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.uint8,
+    np.uint16,
+    np.uint32,
+    np.uint64,
+    np.float16,
+    np.float32,
+    np.float64,
+)
+
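+# Pairs of (public function, PTX cache operator modifier) used to
+# parametrize the tests below.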
+load_operators = (
+    (cuda.ldca, "ca"),
+    (cuda.ldcg, "cg"),
+    (cuda.ldcs, "cs"),
+    (cuda.ldlu, "lu"),
+    (cuda.ldcv, "cv"),
+)
+store_operators = (
+    (cuda.stcg, "cg"),
+    (cuda.stcs, "cs"),
+    (cuda.stwb, "wb"),
+    (cuda.stwt, "wt"),
+)
+
+
+@skip_on_cudasim("Cache hints not supported on simulator")
+class TestCacheHints(CUDATestCase):
+    def test_loads(self):
+        for operator, modifier in load_operators:
+
+            @cuda.jit
+            def f(r, x):
+                for i in range(len(r)):
+                    r[i] = operator(x, i)
+
+            def f_ptr(ptr, r, n):
+                for i in range(n):
+                    r[i] = operator(ptr, i)
+
+            for ty in tested_types:
+                x = np.arange(5).astype(ty)
+                numba_type = typeof(x)
+                bitwidth = numba_type.dtype.bitwidth
+
+                with self.subTest(operator=operator, ty=ty, input_type="array"):
+                    r = np.zeros_like(x)
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    # Check PTX contains a cache-policy load instruction
+                    sig = (numba_type, numba_type)
+                    ptx = f.inspect_asm(signature=sig)
+
+                    self.assertIn(f"ld.global.{modifier}.b{bitwidth}", ptx)
+
+                with self.subTest(
+                    operator=operator, ty=ty, input_type="cpointer"
+                ):
+                    ptr_type = types.CPointer(numba_type.dtype)
+
+                    sig = (ptr_type, numba_type, types.intp)
+                    ptx, _ = cuda.compile_ptx(f_ptr, sig)
+
+                    self.assertIn(f"ld.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_loads_2d(self):
+        for operator, modifier in load_operators:
+
+            @cuda.jit
+            def f(r, x):
+                for i in range(x.shape[0]):
+                    for j in range(x.shape[1]):
+                        r[i, j] = operator(x, (i, j))
+
+            for ty in tested_types:
+                x = np.arange(12).reshape(3, 4).astype(ty)
+                numba_type = typeof(x)
+                bitwidth = numba_type.dtype.bitwidth
+
+                with self.subTest(operator=operator, ty=ty, dims=2):
+                    r = np.zeros_like(x)
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    sig = (numba_type, numba_type)
+                    ptx = f.inspect_asm(signature=sig)
+
+                    self.assertIn(f"ld.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_stores(self):
+        for operator, modifier in store_operators:
+
+            @cuda.jit
+            def f(r, x):
+                for i in range(len(r)):
+                    operator(r, i, x[i])
+
+            def f_ptr(ptr, values, n):
+                for i in range(n):
+                    operator(ptr, i, values[i])
+
+            for ty in tested_types:
+                x = np.arange(5).astype(ty)
+                numba_type = typeof(x)
+                bitwidth = numba_type.dtype.bitwidth
+
+                with self.subTest(operator=operator, ty=ty, input_type="array"):
+                    r = np.zeros_like(x)
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    # Check PTX contains a cache-policy store instruction
+                    sig = (numba_type, numba_type)
+                    ptx = f.inspect_asm(signature=sig)
+
+                    self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
+
+                with self.subTest(
+                    operator=operator, ty=ty, input_type="cpointer"
+                ):
+                    ptr_type = types.CPointer(numba_type.dtype)
+
+                    sig = (ptr_type, numba_type, types.intp)
+                    ptx, _ = cuda.compile_ptx(f_ptr, sig)
+
+                    self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_stores_2d(self):
+        for operator, modifier in store_operators:
+
+            @cuda.jit
+            def f(r, x):
+                for i in range(x.shape[0]):
+                    for j in range(x.shape[1]):
+                        operator(r, (i, j), x[i, j])
+
+            for ty in tested_types:
+                x = np.arange(12).reshape(3, 4).astype(ty)
+                numba_type = typeof(x)
+                bitwidth = numba_type.dtype.bitwidth
+
+                with self.subTest(operator=operator, ty=ty, dims=2):
+                    r = np.zeros_like(x)
+                    f[1, 1](r, x)
+                    np.testing.assert_equal(r, x)
+
+                    sig = (numba_type, numba_type)
+                    ptx = f.inspect_asm(signature=sig)
+
+                    self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
+
+    def test_bad_indices(self):
+        def float_indices(x):
+            cuda.ldcs(x, 1.0)
+
+        sig_1d = (types.float32[::1],)
+
+        msg = "float64 is not a valid index"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(float_indices, sig_1d)
+
+        def too_long_indices(x):
+            cuda.ldcs(x, (1, 2))
+
+        msg = "Expected 1 indices, got 2"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_long_indices, sig_1d)
+
+        def too_short_indices_scalar(x):
+            cuda.ldcs(x, 1)
+
+        def too_short_indices_tuple(x):
+            cuda.ldcs(x, (1,))
+
+        sig_2d = (types.float32[:, ::1],)
+
+        msg = "Expected 2 indices, got a scalar"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_short_indices_scalar, sig_2d)
+
+        msg = "Expected 2 indices, got 1"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(too_short_indices_tuple, sig_2d)
+
+    def test_unsupported_dtypes(self):
+        def load_complex(r, x):
+            r[0] = cuda.ldcs(x, 0)
+
+        # Complex types are rejected: they are neither integer nor float
+        numba_type = typeof(np.zeros(1, dtype=np.complex64))
+        sig = (numba_type, numba_type)
+
+        msg = "ldcs requires array of integer or float type"
+        with self.assertRaisesRegex(errors.TypingError, msg):
+            cuda.compile_ptx(load_complex, sig)
+
+
+if __name__ == "__main__":
+    unittest.main()