63 changes: 63 additions & 0 deletions docs/source/reference/cache_hints.rst
@@ -0,0 +1,63 @@
..
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-2-Clause

.. _cache-hints:

Cache Hints for Memory Operations
=================================

These functions provide explicit control over caching behavior for memory
operations. They generate PTX load and store instructions with cache policy
hints that can improve performance for specific memory access patterns. All
functions accept arrays or pointers of signed or unsigned integer and
floating-point types, at any supported bitwidth.

.. seealso:: `Cache Operators
<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_
in the PTX ISA documentation.

.. function:: numba.cuda.ldca(array, i)

Load element ``i`` from ``array`` with cache-all policy (``ld.global.ca``). This
is the default caching behavior.

.. function:: numba.cuda.ldcg(array, i)

Load element ``i`` from ``array`` with cache-global policy (``ld.global.cg``).
Useful for data shared across thread blocks.

.. function:: numba.cuda.ldcs(array, i)

Load element ``i`` from ``array`` with cache-streaming policy
(``ld.global.cs``). Optimized for streaming data accessed once.

.. function:: numba.cuda.ldlu(array, i)

Load element ``i`` from ``array`` with last-use policy (``ld.global.lu``).
Indicates data is unlikely to be reused.

.. function:: numba.cuda.ldcv(array, i)

Load element ``i`` from ``array`` with cache-volatile policy (``ld.global.cv``).
Used for volatile data that may change externally.

.. function:: numba.cuda.stcg(array, i, value)

Store ``value`` to ``array[i]`` with cache-global policy (``st.global.cg``).
Useful for data shared across thread blocks.

.. function:: numba.cuda.stcs(array, i, value)

Store ``value`` to ``array[i]`` with cache-streaming policy (``st.global.cs``).
Optimized for streaming writes.

.. function:: numba.cuda.stwb(array, i, value)

Store ``value`` to ``array[i]`` with write-back policy (``st.global.wb``). This
is the default caching behavior.

.. function:: numba.cuda.stwt(array, i, value)

Store ``value`` to ``array[i]`` with write-through policy (``st.global.wt``).
Writes through the cache hierarchy to memory.
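
Example
-------

A minimal usage sketch, assuming the functions are exposed as
``numba.cuda`` attributes as documented above; the kernel and array names
are illustrative only:

.. code-block:: python

   from numba import cuda

   @cuda.jit
   def scale(dst, src, alpha):
       i = cuda.grid(1)
       if i < src.size:
           x = cuda.ldcs(src, i)         # streaming load: data read once
           cuda.stcs(dst, i, alpha * x)  # streaming store

Both the load and the store here use the streaming policy, which suits
data that is accessed exactly once.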
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
@@ -9,6 +9,7 @@ Reference documentation

host.rst
kernel.rst
cache_hints.rst
types.rst
memory.rst
libdevice.rst
27 changes: 27 additions & 0 deletions numba_cuda/numba/cuda/api_util.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from numba import types
from numba.core import cgutils
import numpy as np

import functools
@@ -47,3 +49,28 @@ def _fill_stride_by_order(shape, dtype, order):
else:
raise ValueError("must be either C/F order")
return tuple(strides)


def normalize_indices(context, builder, indty, inds, aryty, valty):
"""
Convert integer indices into tuple of intp
"""
if indty in types.integer_domain:
indty = types.UniTuple(dtype=indty, count=1)
indices = [inds]
else:
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
indices = [
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
]

dtype = aryty.dtype
if dtype != valty:
raise TypeError("expect %s but got %s" % (dtype, valty))

if aryty.ndim != len(indty):
raise TypeError(
"indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
)

return indty, indices
287 changes: 287 additions & 0 deletions numba_cuda/numba/cuda/cache_hints.py
@@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from llvmlite import ir
from numba import types
from numba.cuda import cgutils
from numba.cuda.extending import intrinsic, overload
from numba.cuda.core.errors import NumbaTypeError
from numba.cuda.api_util import normalize_indices

# Docs references:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#load-functions-using-cache-hints


def ldca(array, i):
"""Generate a `ld.global.ca` instruction for element `i` of an array."""


def ldcg(array, i):
"""Generate a `ld.global.cg` instruction for element `i` of an array."""


def ldcs(array, i):
"""Generate a `ld.global.cs` instruction for element `i` of an array."""


def ldlu(array, i):
"""Generate a `ld.global.lu` instruction for element `i` of an array."""


def ldcv(array, i):
"""Generate a `ld.global.cv` instruction for element `i` of an array."""


def stcg(array, i, value):
"""Generate a `st.global.cg` instruction for element `i` of an array."""


def stcs(array, i, value):
"""Generate a `st.global.cs` instruction for element `i` of an array."""


def stwb(array, i, value):
"""Generate a `st.global.wb` instruction for element `i` of an array."""


def stwt(array, i, value):
"""Generate a `st.global.wt` instruction for element `i` of an array."""


# See
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
# for background on the choice of "r" for 8-bit operands - there is no
# inline asm constraint letter for 8-bit operands, but the register
# operand of loads and stores is permitted to be wider than 8 bits.
CONSTRAINT_MAP = {1: "b", 8: "r", 16: "h", 32: "r", 64: "l", 128: "q"}
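

# Illustrative example (exposition only): for a 32-bit load with the "cg"
# operator, the pieces above compose into the inline assembly
# "ld.global.cg.b32 $0, [$1];" with constraints "=r,l", i.e. a 32-bit
# result in an "r" register and a 64-bit ("l") address operand.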


def _validate_arguments(instruction, array, index):
is_array = isinstance(array, types.Array)
is_pointer = isinstance(array, types.CPointer)
if not (is_array or is_pointer):
msg = f"{instruction} operates on arrays or pointers. Got type {array}"
raise NumbaTypeError(msg)

valid_index = False

if isinstance(index, types.Integer):
if is_array and array.ndim != 1:
# for pointers, any integer index is valid
msg = f"Expected {array.ndim} indices, got a scalar"
raise NumbaTypeError(msg)
valid_index = True

if isinstance(index, types.UniTuple):
if is_pointer:
msg = f"Pointers only support scalar indexing, got tuple of {index.count}"
raise NumbaTypeError(msg)

if index.count != array.ndim:
msg = f"Expected {array.ndim} indices, got {index.count}"
raise NumbaTypeError(msg)

if isinstance(index.dtype, types.Integer):
valid_index = True

if not valid_index:
raise NumbaTypeError(f"{index} is not a valid index")


def _validate_bitwidth(instruction, array):
dtype = array.dtype

if not isinstance(dtype, (types.Integer, types.Float)):
msg = (
f"{instruction} requires array of integer or float type, "
f"got {dtype}"
)
raise NumbaTypeError(msg)

bitwidth = dtype.bitwidth
if bitwidth not in CONSTRAINT_MAP:
valid_widths = sorted(CONSTRAINT_MAP.keys())
msg = (
f"{instruction} requires array dtype with bitwidth "
f"in {valid_widths}, got bitwidth {bitwidth}"
)
raise NumbaTypeError(msg)


def _get_element_pointer(
context, builder, index_type, indices, array_type, array
):
if isinstance(array_type, types.CPointer):
return builder.gep(array, [indices])
else:
index_type, indices = normalize_indices(
context,
builder,
index_type,
indices,
array_type,
array_type.dtype,
)
array_struct = context.make_array(array_type)(
context, builder, value=array
)
return cgutils.get_item_pointer(
context,
builder,
array_type,
array_struct,
indices,
wraparound=True,
)


def ld_cache_operator(operator):
@intrinsic
def impl(typingctx, array, index):
_validate_arguments(f"ld{operator}", array, index)
_validate_bitwidth(f"ld{operator}", array)

signature = array.dtype(array, index)

def codegen(context, builder, sig, args):
array_type, index_type = sig.args
loaded_type = context.get_value_type(array_type.dtype)
ptr_type = loaded_type.as_pointer()
ld_fn_type = ir.FunctionType(loaded_type, [ptr_type])

array, indices = args

ptr = _get_element_pointer(
context, builder, index_type, indices, array_type, array
)

bitwidth = array_type.dtype.bitwidth
inst = f"ld.global.{operator}.b{bitwidth}"
constraints = f"={CONSTRAINT_MAP[bitwidth]},l"
ld_asm = ir.InlineAsm(ld_fn_type, f"{inst} $0, [$1];", constraints)
return builder.call(ld_asm, [ptr])

return signature, codegen

return impl


ldca_intrinsic = ld_cache_operator("ca")
ldcg_intrinsic = ld_cache_operator("cg")
ldcs_intrinsic = ld_cache_operator("cs")
ldlu_intrinsic = ld_cache_operator("lu")
ldcv_intrinsic = ld_cache_operator("cv")


def st_cache_operator(operator):
@intrinsic
def impl(typingctx, array, index, value):
_validate_arguments(f"st{operator}", array, index)
_validate_bitwidth(f"st{operator}", array)

signature = types.void(array, index, value)

def codegen(context, builder, sig, args):
array_type, index_type, value_type = sig.args
stored_type = context.get_value_type(array_type.dtype)
ptr_type = stored_type.as_pointer()
st_fn_type = ir.FunctionType(ir.VoidType(), [ptr_type, stored_type])

array, indices, value = args

ptr = _get_element_pointer(
context, builder, index_type, indices, array_type, array
)

casted_value = context.cast(
builder, value, value_type, array_type.dtype
)

bitwidth = array_type.dtype.bitwidth
inst = f"st.global.{operator}.b{bitwidth}"
constraints = f"l,{CONSTRAINT_MAP[bitwidth]},~{{memory}}"
st_asm = ir.InlineAsm(st_fn_type, f"{inst} [$0], $1;", constraints)
builder.call(st_asm, [ptr, casted_value])

return signature, codegen

return impl


stcg_intrinsic = st_cache_operator("cg")
stcs_intrinsic = st_cache_operator("cs")
stwb_intrinsic = st_cache_operator("wb")
stwt_intrinsic = st_cache_operator("wt")


@overload(ldca, target="cuda")
def ol_ldca(array, i):
def impl(array, i):
return ldca_intrinsic(array, i)

return impl


@overload(ldcg, target="cuda")
def ol_ldcg(array, i):
def impl(array, i):
return ldcg_intrinsic(array, i)

return impl


@overload(ldcs, target="cuda")
def ol_ldcs(array, i):
def impl(array, i):
return ldcs_intrinsic(array, i)

return impl


@overload(ldlu, target="cuda")
def ol_ldlu(array, i):
def impl(array, i):
return ldlu_intrinsic(array, i)

return impl


@overload(ldcv, target="cuda")
def ol_ldcv(array, i):
def impl(array, i):
return ldcv_intrinsic(array, i)

return impl


@overload(stcg, target="cuda")
def ol_stcg(array, i, value):
def impl(array, i, value):
return stcg_intrinsic(array, i, value)

return impl


@overload(stcs, target="cuda")
def ol_stcs(array, i, value):
def impl(array, i, value):
return stcs_intrinsic(array, i, value)

return impl


@overload(stwb, target="cuda")
def ol_stwb(array, i, value):
def impl(array, i, value):
return stwb_intrinsic(array, i, value)

return impl


@overload(stwt, target="cuda")
def ol_stwt(array, i, value):
def impl(array, i, value):
return stwt_intrinsic(array, i, value)

return impl
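

# Usage sketch (assumed names, for illustration): inside a @cuda.jit kernel
# these functions are called like ordinary element accesses, e.g.
#
#     x = ldcg(src, i)
#     stcg(dst, i, x)
#
# One way to verify the emitted hint is to compile a kernel and search its
# PTX (e.g. via the dispatcher's inspect_asm()) for the expected
# instruction, such as "ld.global.cg.b32" for a float32 load.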