Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 122 additions & 21 deletions python/cuda_cccl/cuda/compute/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,137 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
import types
from typing import Any, Callable, Hashable

import numpy as np

try:
from cuda.core import Device
except ImportError:
from cuda.core.experimental import Device

from ._utils.protocols import get_dtype
from .typing import DeviceArrayLike, GpuStruct

# Registry thet maps type -> key function for extracting cache key
# from a value of that type.
_KEY_FUNCTIONS: dict[type, Callable[[Any], Hashable]] = {}


def _key_for(value: Any) -> Hashable:
"""
Extract a cache key from a value using the registered KEY_FUNCTIONS.

This function checks the type of the value and delegates to the
appropriate registered keyer. Falls back to using the value
directly if no keyer is registered.

Args:
value: The value to extract a cache key from

Returns:
A hashable cache key
"""
# Handle sequences (lists, tuples) by recursively converting to tuple
if isinstance(value, (list, tuple)):
return tuple(_key_for(item) for item in value)

# Check for exact type match first
value_type = type(value)
if value_type in _KEY_FUNCTIONS:
return _KEY_FUNCTIONS[value_type](value)

# Check for instance match (handles inheritance)
for registered_type, keyer in _KEY_FUNCTIONS.items():
if isinstance(value, registered_type):
return keyer(value)

# Fallback: use value directly (assumes it's hashable)
return value


def _make_cache_key_from_args(*args, **kwargs) -> tuple:
    """
    Build a cache key tuple from a call's arguments.

    Positional arguments are keyed in order. Keyword arguments, when
    present, are appended as a single trailing tuple of ``(name, key)``
    pairs sorted by name, so call-site ordering does not affect the key.

    Args:
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        A tuple containing the extracted cache keys
    """
    key = tuple(_key_for(a) for a in args)

    if not kwargs:
        return key

    keyword_part = tuple(
        (name, _key_for(val)) for name, val in sorted(kwargs.items())
    )
    return key + (keyword_part,)


# Central registry of all algorithm caches
_cache_registry: dict[str, object] = {}


def cache_with_key(key):
class _CacheWithRegisteredKeyFunctions:
"""
Decorator to cache the result of the decorated function. Uses the
provided `key` function to compute the key for cache lookup. `key`
receives all arguments passed to the function.
Decorator to cache the result of the decorated function.

Notes
-----
The CUDA compute capability of the current device is appended to
the cache key returned by `key`.

The decorated function is automatically registered in the central
cache registry for easy cache management.
The cache key is automatically computed from the decorated function's
arguments using the registered key functions.
"""

def deco(func):
cache = {}
def __call__(self, func: Callable) -> Callable:
"""
Decorator to cache the result of the decorated function.

Args:
func: The function whose result is to be cached.

Notes
-----
The CUDA compute capability of the current device is appended to
the cache key.
"""
cache: dict = {}

@functools.wraps(func)
def inner(*args, **kwargs):
cc = Device().compute_capability
cache_key = (key(*args, **kwargs), tuple(cc))
user_cache_key = _make_cache_key_from_args(*args, **kwargs)
cache_key = (user_cache_key, tuple(cc))
if cache_key not in cache:
result = func(*args, **kwargs)
cache[cache_key] = result
return cache[cache_key]

def cache_clear():
cache.clear()

inner.cache_clear = cache_clear
inner.cache_clear = cache.clear # type: ignore[attr-defined]

# Register the cache in the central registry
cache_name = func.__qualname__
_cache_registry[cache_name] = inner
_cache_registry[func.__qualname__] = inner

return inner

return deco
def register(self, type_: type, key_function: Callable[[Any], Hashable]) -> None:
"""
Register a key function for a specific type.

A key function extracts a hashable cache key from a value.

Args:
type_: The type to register
key_function: A callable that takes an instance of type_ and
returns a hashable cache key
"""
_KEY_FUNCTIONS[type_] = key_function


cache_with_registered_key_functions = _CacheWithRegisteredKeyFunctions()


def _hash_device_array_like(value):
Expand Down Expand Up @@ -138,3 +219,23 @@ def __hash__(self):

def __repr__(self):
return str(self._func)


# Register keyers for built-in types
# Include fully-qualified type name to distinguish np.ndarray from cp.ndarray from GpuStruct
def _type_fqn(v):
return f"{type(v).__module__}.{type(v).__name__}"


# numpy arrays: key on dtype; the fixed "numpy.ndarray" tag keeps host
# arrays distinct from device arrays keyed below.
cache_with_registered_key_functions.register(
    np.ndarray, lambda arr: ("numpy.ndarray", arr.dtype)
)
# Plain Python functions: wrap in CachableFunction, which supplies the
# hashing behavior used for cache keys.
cache_with_registered_key_functions.register(
    types.FunctionType, lambda fn: CachableFunction(fn)
)
# Device arrays (e.g. cupy): fully-qualified type name plus element dtype.
cache_with_registered_key_functions.register(
    DeviceArrayLike, lambda v: (_type_fqn(v), get_dtype(v))
)
# GpuStructs: fully-qualified type name plus the struct's dtype.
cache_with_registered_key_functions.register(
    GpuStruct, lambda v: (_type_fqn(v), v.dtype)
)
33 changes: 3 additions & 30 deletions python/cuda_cccl/cuda/compute/algorithms/_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,14 @@

from .. import _bindings
from .. import _cccl_interop as cccl
from .._caching import cache_with_key
from .._caching import cache_with_registered_key_functions
from .._cccl_interop import call_build, set_cccl_iterator_state, to_cccl_value_state
from .._utils.protocols import get_data_pointer, get_dtype, validate_and_get_stream
from .._utils.protocols import get_data_pointer, validate_and_get_stream
from .._utils.temp_storage_buffer import TempStorageBuffer
from ..iterators._iterators import IteratorBase
from ..typing import DeviceArrayLike


def make_cache_key(
    d_samples: DeviceArrayLike | IteratorBase,
    d_histogram: DeviceArrayLike,
    d_num_output_levels: DeviceArrayLike,
    h_lower_level: np.ndarray,
    h_upper_level: np.ndarray,
    num_samples: int,
):
    """
    Build the cache key for a histogram build: the samples' iterator kind
    or dtype, the dtypes of the remaining arguments, and the sample count.
    """
    # Iterators are keyed by their kind; device arrays by their dtype.
    if isinstance(d_samples, IteratorBase):
        samples_key = d_samples.kind
    else:
        samples_key = get_dtype(d_samples)

    return (
        samples_key,
        get_dtype(d_histogram),
        get_dtype(d_num_output_levels),
        h_lower_level.dtype,
        h_upper_level.dtype,
        num_samples,
    )


class _Histogram:
__slots__ = [
"num_rows",
Expand Down Expand Up @@ -134,7 +107,7 @@ def __call__(
return temp_storage_bytes


@cache_with_key(make_cache_key)
@cache_with_registered_key_functions
def make_histogram_even(
d_samples: DeviceArrayLike | IteratorBase,
d_histogram: DeviceArrayLike,
Expand Down
46 changes: 9 additions & 37 deletions python/cuda_cccl/cuda/compute/algorithms/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

from .. import _bindings
from .. import _cccl_interop as cccl
from .._caching import cache_with_key
from .._caching import cache_with_registered_key_functions
from .._cccl_interop import (
call_build,
get_value_type,
set_cccl_iterator_state,
to_cccl_value_state,
)
from .._utils import protocols
from .._utils.protocols import get_data_pointer, validate_and_get_stream
from .._utils.temp_storage_buffer import TempStorageBuffer
from ..determinism import Determinism
Expand Down Expand Up @@ -106,40 +105,7 @@ def __call__(
return temp_storage_bytes


def _make_cache_key(
    d_in: DeviceArrayLike | IteratorBase,
    d_out: DeviceArrayLike | IteratorBase,
    op: OpAdapter,
    h_init: np.ndarray | GpuStruct,
    **kwargs,
):
    """
    Build the cache key for a reduce build: the kind (iterators) or dtype
    (device arrays) of input and output, the operator's cache key, the
    initial value's dtype, and the requested determinism mode.
    """

    def _buffer_key(buf):
        # Iterators are keyed by kind; device arrays by dtype.
        return buf.kind if isinstance(buf, IteratorBase) else protocols.get_dtype(buf)

    return (
        _buffer_key(d_in),
        _buffer_key(d_out),
        op.get_cache_key(),
        h_init.dtype,
        kwargs.get("determinism", Determinism.RUN_TO_RUN),
    )


@cache_with_key(_make_cache_key)
def _make_reduce_into_cached(
    d_in: DeviceArrayLike | IteratorBase,
    d_out: DeviceArrayLike | IteratorBase,
    op: OpAdapter,
    h_init: np.ndarray | GpuStruct,
    **kwargs,
):
    """Cached factory: build (or fetch a cached) ``_Reduce`` for these arguments."""
    determinism = kwargs.get("determinism", Determinism.RUN_TO_RUN)
    return _Reduce(d_in, d_out, op, h_init, determinism)


# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
@cache_with_registered_key_functions
def make_reduce_into(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
Expand Down Expand Up @@ -167,7 +133,13 @@ def make_reduce_into(
A callable object that can be used to perform the reduction
"""
op_adapter = make_op_adapter(op)
return _make_reduce_into_cached(d_in, d_out, op_adapter, h_init, **kwargs)
return _Reduce(
d_in,
d_out,
op_adapter,
h_init,
kwargs.get("determinism", Determinism.RUN_TO_RUN),
)


def reduce_into(
Expand Down
Loading