diff --git a/python/cuda_cccl/cuda/compute/_caching.py b/python/cuda_cccl/cuda/compute/_caching.py index 7c6bcd0ec02..f280376a589 100644 --- a/python/cuda_cccl/cuda/compute/_caching.py +++ b/python/cuda_cccl/cuda/compute/_caching.py @@ -4,56 +4,137 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools +import types +from typing import Any, Callable, Hashable + +import numpy as np try: from cuda.core import Device except ImportError: from cuda.core.experimental import Device +from ._utils.protocols import get_dtype +from .typing import DeviceArrayLike, GpuStruct + +# Registry that maps type -> key function for extracting cache key +# from a value of that type. +_KEY_FUNCTIONS: dict[type, Callable[[Any], Hashable]] = {} + + +def _key_for(value: Any) -> Hashable: + """ + Extract a cache key from a value using the registered KEY_FUNCTIONS. + + This function checks the type of the value and delegates to the + appropriate registered keyer. Falls back to using the value + directly if no keyer is registered. + + Args: + value: The value to extract a cache key from + + Returns: + A hashable cache key + """ + # Handle sequences (lists, tuples) by recursively converting to tuple + if isinstance(value, (list, tuple)): + return tuple(_key_for(item) for item in value) + + # Check for exact type match first + value_type = type(value) + if value_type in _KEY_FUNCTIONS: + return _KEY_FUNCTIONS[value_type](value) + + # Check for instance match (handles inheritance) + for registered_type, keyer in _KEY_FUNCTIONS.items(): + if isinstance(value, registered_type): + return keyer(value) + + # Fallback: use value directly (assumes it's hashable) + return value + + +def _make_cache_key_from_args(*args, **kwargs) -> tuple: + """ + Create a cache key from function arguments. 
+ + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + A tuple containing the extracted cache keys + """ + + positional_keys = tuple(_key_for(arg) for arg in args) + + # Sort kwargs by key name for consistent ordering + if kwargs: + sorted_kwargs = sorted(kwargs.items()) + kwarg_keys = tuple((k, _key_for(v)) for k, v in sorted_kwargs) + return positional_keys + (kwarg_keys,) + + return positional_keys + # Central registry of all algorithm caches _cache_registry: dict[str, object] = {} -def cache_with_key(key): +class _CacheWithRegisteredKeyFunctions: """ - Decorator to cache the result of the decorated function. Uses the - provided `key` function to compute the key for cache lookup. `key` - receives all arguments passed to the function. + Decorator to cache the result of the decorated function. - Notes - ----- - The CUDA compute capability of the current device is appended to - the cache key returned by `key`. - - The decorated function is automatically registered in the central - cache registry for easy cache management. + The cache key is automatically computed from the decorated function's + arguments using the registered key functions. """ - def deco(func): - cache = {} + def __call__(self, func: Callable) -> Callable: + """ + Decorator to cache the result of the decorated function. + + Args: + func: The function whose result is to be cached. + + Notes + ----- + The CUDA compute capability of the current device is appended to + the cache key. 
+ """ + cache: dict = {} @functools.wraps(func) def inner(*args, **kwargs): cc = Device().compute_capability - cache_key = (key(*args, **kwargs), tuple(cc)) + user_cache_key = _make_cache_key_from_args(*args, **kwargs) + cache_key = (user_cache_key, tuple(cc)) if cache_key not in cache: result = func(*args, **kwargs) cache[cache_key] = result return cache[cache_key] - def cache_clear(): - cache.clear() - - inner.cache_clear = cache_clear + inner.cache_clear = cache.clear # type: ignore[attr-defined] # Register the cache in the central registry - cache_name = func.__qualname__ - _cache_registry[cache_name] = inner + _cache_registry[func.__qualname__] = inner return inner - return deco + def register(self, type_: type, key_function: Callable[[Any], Hashable]) -> None: + """ + Register a key function for a specific type. + + A key function extracts a hashable cache key from a value. + + Args: + type_: The type to register + key_function: A callable that takes an instance of type_ and + returns a hashable cache key + """ + _KEY_FUNCTIONS[type_] = key_function + + +cache_with_registered_key_functions = _CacheWithRegisteredKeyFunctions() def _hash_device_array_like(value): @@ -138,3 +219,23 @@ def __hash__(self): def __repr__(self): return str(self._func) + + +# Register keyers for built-in types +# Include fully-qualified type name to distinguish np.ndarray from cp.ndarray from GpuStruct +def _type_fqn(v): + return f"{type(v).__module__}.{type(v).__name__}" + + +cache_with_registered_key_functions.register( + np.ndarray, lambda arr: ("numpy.ndarray", arr.dtype) +) +cache_with_registered_key_functions.register( + types.FunctionType, lambda fn: CachableFunction(fn) +) +cache_with_registered_key_functions.register( + DeviceArrayLike, lambda v: (_type_fqn(v), get_dtype(v)) +) +cache_with_registered_key_functions.register( + GpuStruct, lambda v: (_type_fqn(v), v.dtype) +) diff --git a/python/cuda_cccl/cuda/compute/algorithms/_histogram.py 
b/python/cuda_cccl/cuda/compute/algorithms/_histogram.py index 9e15d3c3ff0..075ee820dbb 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_histogram.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_histogram.py @@ -9,41 +9,14 @@ from .. import _bindings from .. import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import call_build, set_cccl_iterator_state, to_cccl_value_state -from .._utils.protocols import get_data_pointer, get_dtype, validate_and_get_stream +from .._utils.protocols import get_data_pointer, validate_and_get_stream from .._utils.temp_storage_buffer import TempStorageBuffer from ..iterators._iterators import IteratorBase from ..typing import DeviceArrayLike -def make_cache_key( - d_samples: DeviceArrayLike | IteratorBase, - d_histogram: DeviceArrayLike, - d_num_output_levels: DeviceArrayLike, - h_lower_level: np.ndarray, - h_upper_level: np.ndarray, - num_samples: int, -): - d_samples_key = ( - d_samples.kind if isinstance(d_samples, IteratorBase) else get_dtype(d_samples) - ) - - d_histogram_key = get_dtype(d_histogram) - d_num_output_levels_key = get_dtype(d_num_output_levels) - d_lower_level_key = h_lower_level.dtype - d_upper_level_key = h_upper_level.dtype - - return ( - d_samples_key, - d_histogram_key, - d_num_output_levels_key, - d_lower_level_key, - d_upper_level_key, - num_samples, - ) - - class _Histogram: __slots__ = [ "num_rows", @@ -134,7 +107,7 @@ def __call__( return temp_storage_bytes -@cache_with_key(make_cache_key) +@cache_with_registered_key_functions def make_histogram_even( d_samples: DeviceArrayLike | IteratorBase, d_histogram: DeviceArrayLike, diff --git a/python/cuda_cccl/cuda/compute/algorithms/_reduce.py b/python/cuda_cccl/cuda/compute/algorithms/_reduce.py index d527a40a4e9..3605b0bcad7 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_reduce.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_reduce.py @@ -9,14 +9,13 
@@ from .. import _bindings from .. import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import ( call_build, get_value_type, set_cccl_iterator_state, to_cccl_value_state, ) -from .._utils import protocols from .._utils.protocols import get_data_pointer, validate_and_get_stream from .._utils.temp_storage_buffer import TempStorageBuffer from ..determinism import Determinism @@ -106,40 +105,7 @@ def __call__( return temp_storage_bytes -def _make_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, - h_init: np.ndarray | GpuStruct, - **kwargs, -): - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - h_init_key = h_init.dtype - determinism = kwargs.get("determinism", Determinism.RUN_TO_RUN) - return (d_in_key, d_out_key, op.get_cache_key(), h_init_key, determinism) - - -@cache_with_key(_make_cache_key) -def _make_reduce_into_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, - h_init: np.ndarray | GpuStruct, - **kwargs, -): - """Internal cached factory for _Reduce.""" - return _Reduce( - d_in, d_out, op, h_init, kwargs.get("determinism", Determinism.RUN_TO_RUN) - ) - - -# TODO Figure out `sum` without operator and initial value -# TODO Accept stream +@cache_with_registered_key_functions def make_reduce_into( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | IteratorBase, @@ -167,7 +133,13 @@ def make_reduce_into( A callable object that can be used to perform the reduction """ op_adapter = make_op_adapter(op) - return _make_reduce_into_cached(d_in, d_out, op_adapter, h_init, **kwargs) + return _Reduce( + d_in, + d_out, + op_adapter, + h_init, + kwargs.get("determinism", Determinism.RUN_TO_RUN), + ) def 
reduce_into( diff --git a/python/cuda_cccl/cuda/compute/algorithms/_scan.py b/python/cuda_cccl/cuda/compute/algorithms/_scan.py index ef9861aca22..b6082374e28 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_scan.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_scan.py @@ -10,7 +10,7 @@ from .. import _bindings from .. import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import ( call_build, get_value_type, @@ -164,56 +164,9 @@ def __call__( return temp_storage_bytes -def _make_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, - init_value: np.ndarray | DeviceArrayLike | GpuStruct | None, -): - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - - init_kind_key = get_init_kind(init_value) - match init_kind_key: - case _bindings.InitKind.NO_INIT: - init_value_key = None - case _bindings.InitKind.FUTURE_VALUE_INIT: - init_value_key = protocols.get_dtype(cast(DeviceArrayLike, init_value)) - case _bindings.InitKind.VALUE_INIT: - init_value = cast(np.ndarray | GpuStruct, init_value) - init_value_key = init_value.dtype - - return (d_in_key, d_out_key, op.get_cache_key(), init_value_key, init_kind_key) - - -@cache_with_key(_make_cache_key) -def _make_exclusive_scan_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, - init_value: np.ndarray | DeviceArrayLike | GpuStruct | None, -): - """Internal cached factory for exclusive _Scan.""" - return _Scan(d_in, d_out, op, init_value, False) - - -@cache_with_key(_make_cache_key) -def _make_inclusive_scan_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, - init_value: np.ndarray | DeviceArrayLike | GpuStruct | None, -): - 
"""Internal cached factory for inclusive _Scan.""" - return _Scan(d_in, d_out, op, init_value, True) - - # TODO Figure out `sum` without operator and initial value # TODO Accept stream +@cache_with_registered_key_functions def make_exclusive_scan( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | IteratorBase, @@ -240,7 +193,7 @@ def make_exclusive_scan( A callable object that can be used to perform the scan """ op_adapter = make_op_adapter(op) - return _make_exclusive_scan_cached(d_in, d_out, op_adapter, init_value) + return _Scan(d_in, d_out, op_adapter, init_value, False) def exclusive_scan( @@ -280,6 +233,7 @@ def exclusive_scan( # TODO Figure out `sum` without operator and initial value # TODO Accept stream +@cache_with_registered_key_functions def make_inclusive_scan( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | IteratorBase, @@ -306,7 +260,7 @@ def make_inclusive_scan( A callable object that can be used to perform the scan """ op_adapter = make_op_adapter(op) - return _make_inclusive_scan_cached(d_in, d_out, op_adapter, init_value) + return _Scan(d_in, d_out, op_adapter, init_value, True) def inclusive_scan( diff --git a/python/cuda_cccl/cuda/compute/algorithms/_segmented_reduce.py b/python/cuda_cccl/cuda/compute/algorithms/_segmented_reduce.py index ff50b294b69..00e39c4efcd 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_segmented_reduce.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_segmented_reduce.py @@ -4,14 +4,13 @@ from .. import _bindings from .. 
import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import ( call_build, get_value_type, set_cccl_iterator_state, to_cccl_value_state, ) -from .._utils import protocols from .._utils.protocols import ( get_data_pointer, validate_and_get_stream, @@ -125,52 +124,7 @@ def __call__( return temp_storage_bytes -def _to_key(d_in: DeviceArrayLike | IteratorBase): - "Return key for an input array-like argument or an iterator" - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - return d_in_key - - -def _make_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - start_offsets_in: DeviceArrayLike | IteratorBase, - end_offsets_in: DeviceArrayLike | IteratorBase, - op: OpAdapter, - h_init: np.ndarray | GpuStruct, -): - d_in_key = _to_key(d_in) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - start_offsets_in_key = _to_key(start_offsets_in) - end_offsets_in_key = _to_key(end_offsets_in) - h_init_key = h_init.dtype - return ( - d_in_key, - d_out_key, - start_offsets_in_key, - end_offsets_in_key, - op.get_cache_key(), - h_init_key, - ) - - -@cache_with_key(_make_cache_key) -def _make_segmented_reduce_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - start_offsets_in: DeviceArrayLike | IteratorBase, - end_offsets_in: DeviceArrayLike | IteratorBase, - op: OpAdapter, - h_init: np.ndarray | GpuStruct, -): - """Internal cached factory for _SegmentedReduce.""" - return _SegmentedReduce(d_in, d_out, start_offsets_in, end_offsets_in, op, h_init) - - +@cache_with_registered_key_functions def make_segmented_reduce( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | IteratorBase, @@ -201,7 +155,7 @@ def make_segmented_reduce( A callable object that can be used to perform the reduction """ op_adapter = 
make_op_adapter(op) - return _make_segmented_reduce_cached( + return _SegmentedReduce( d_in, d_out, start_offsets_in, end_offsets_in, op_adapter, h_init ) diff --git a/python/cuda_cccl/cuda/compute/algorithms/_select.py b/python/cuda_cccl/cuda/compute/algorithms/_select.py index b0b939f5092..b6a3de7dea6 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_select.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_select.py @@ -5,8 +5,7 @@ from typing import Callable -from .._caching import cache_with_key -from .._utils import protocols +from .._caching import cache_with_registered_key_functions from .._utils.temp_storage_buffer import TempStorageBuffer from ..iterators._factories import DiscardIterator from ..iterators._iterators import IteratorBase @@ -15,23 +14,6 @@ from ._three_way_partition import make_three_way_partition -def _make_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - d_num_selected_out: DeviceArrayLike, - cond: OpAdapter, -): - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - d_num_selected_out_key = protocols.get_dtype(d_num_selected_out) - - return (d_in_key, d_out_key, d_num_selected_out_key, cond.get_cache_key()) - - class _Select: __slots__ = ["partitioner", "discard_second", "discard_unselected"] @@ -79,20 +61,7 @@ def __call__( ) -@cache_with_key(_make_cache_key) -def _make_select_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - d_num_selected_out: DeviceArrayLike, - cond: OpAdapter, -): - """Internal cached factory for _Select.""" - # Note: _Select internally calls make_three_way_partition which will - # normalize the cond. But we've already normalized it, so the Op - # will be passed through make_op unchanged. 
- return _Select(d_in, d_out, d_num_selected_out, cond) - - +@cache_with_registered_key_functions def make_select( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | IteratorBase, @@ -124,7 +93,10 @@ def make_select( A callable object that performs the selection operation. """ cond_adapter = make_op_adapter(cond) - return _make_select_cached(d_in, d_out, d_num_selected_out, cond_adapter) + # Note: _Select internally calls make_three_way_partition which will + # normalize the cond. But we've already normalized it, so the Op + # will be passed through make_op unchanged. + return _Select(d_in, d_out, d_num_selected_out, cond_adapter) def select( diff --git a/python/cuda_cccl/cuda/compute/algorithms/_sort/_merge_sort.py b/python/cuda_cccl/cuda/compute/algorithms/_sort/_merge_sort.py index 84d5134c537..0092aed4aed 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_sort/_merge_sort.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_sort/_merge_sort.py @@ -9,9 +9,8 @@ from ... import _bindings from ... 
import _cccl_interop as cccl -from ..._caching import cache_with_key +from ..._caching import cache_with_registered_key_functions from ..._cccl_interop import call_build, set_cccl_iterator_state -from ..._utils import protocols from ..._utils.protocols import ( get_data_pointer, validate_and_get_stream, @@ -22,45 +21,6 @@ from ...typing import DeviceArrayLike -def _make_cache_key( - d_in_keys: DeviceArrayLike | IteratorBase, - d_in_items: DeviceArrayLike | IteratorBase | None, - d_out_keys: DeviceArrayLike, - d_out_items: DeviceArrayLike | None, - op: OpAdapter, -): - d_in_keys_key = ( - d_in_keys.kind - if isinstance(d_in_keys, IteratorBase) - else protocols.get_dtype(d_in_keys) - ) - if d_in_items is None: - d_in_items_key = None - else: - d_in_items_key = ( - d_in_items.kind - if isinstance(d_in_items, IteratorBase) - else protocols.get_dtype(d_in_items) - ) - d_out_keys_key = protocols.get_dtype(d_out_keys) - if d_out_items is None: - d_out_items_key = None - else: - d_out_items_key = ( - d_out_items.kind - if isinstance(d_out_items, IteratorBase) - else protocols.get_dtype(d_out_items) - ) - - return ( - d_in_keys_key, - d_in_items_key, - d_out_keys_key, - d_out_items_key, - op.get_cache_key(), - ) - - class _MergeSort: __slots__ = [ "d_in_keys_cccl", @@ -149,18 +109,7 @@ def __call__( return temp_storage_bytes -@cache_with_key(_make_cache_key) -def _make_merge_sort_cached( - d_in_keys: DeviceArrayLike | IteratorBase, - d_in_items: DeviceArrayLike | IteratorBase | None, - d_out_keys: DeviceArrayLike, - d_out_items: DeviceArrayLike | None, - op: OpAdapter, -): - """Internal cached factory for _MergeSort.""" - return _MergeSort(d_in_keys, d_in_items, d_out_keys, d_out_items, op) - - +@cache_with_registered_key_functions def make_merge_sort( d_in_keys: DeviceArrayLike | IteratorBase, d_in_items: DeviceArrayLike | IteratorBase | None, @@ -189,9 +138,7 @@ def make_merge_sort( A callable object that can be used to perform the merge sort """ op_adapter = 
make_op_adapter(op) - return _make_merge_sort_cached( - d_in_keys, d_in_items, d_out_keys, d_out_items, op_adapter - ) + return _MergeSort(d_in_keys, d_in_items, d_out_keys, d_out_items, op_adapter) def merge_sort( diff --git a/python/cuda_cccl/cuda/compute/algorithms/_sort/_radix_sort.py b/python/cuda_cccl/cuda/compute/algorithms/_sort/_radix_sort.py index 1080143018a..ceb9b3e1545 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_sort/_radix_sort.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_sort/_radix_sort.py @@ -5,7 +5,7 @@ from ... import _bindings from ... import _cccl_interop as cccl -from ..._caching import cache_with_key +from ..._caching import cache_with_registered_key_functions from ..._cccl_interop import call_build, set_cccl_iterator_state from ..._utils.protocols import ( get_data_pointer, @@ -17,35 +17,6 @@ from ._sort_common import DoubleBuffer, SortOrder, _get_arrays -def make_cache_key( - d_in_keys: DeviceArrayLike | DoubleBuffer, - d_out_keys: DeviceArrayLike | None, - d_in_values: DeviceArrayLike | DoubleBuffer | None, - d_out_values: DeviceArrayLike | None, - order: SortOrder, -): - d_in_keys_array, d_out_keys_array, d_in_values_array, d_out_values_array = ( - _get_arrays(d_in_keys, d_out_keys, d_in_values, d_out_values) - ) - - d_in_keys_key = get_dtype(d_in_keys_array) - d_in_values_key = ( - None if d_in_values_array is None else get_dtype(d_in_values_array) - ) - d_out_keys_key = get_dtype(d_out_keys_array) - d_out_values_key = ( - None if d_out_values_array is None else get_dtype(d_out_values_array) - ) - - return ( - d_in_keys_key, - d_out_keys_key, - d_in_values_key, - d_out_values_key, - order, - ) - - class _RadixSort: __slots__ = [ "d_in_keys_cccl", @@ -164,7 +135,7 @@ def __call__( return temp_storage_bytes -@cache_with_key(make_cache_key) +@cache_with_registered_key_functions def make_radix_sort( d_in_keys: DeviceArrayLike | DoubleBuffer, d_out_keys: DeviceArrayLike | None, diff --git 
a/python/cuda_cccl/cuda/compute/algorithms/_sort/_segmented_sort.py b/python/cuda_cccl/cuda/compute/algorithms/_sort/_segmented_sort.py index 168b53a0f54..004e01e3d0e 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_sort/_segmented_sort.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_sort/_segmented_sort.py @@ -6,11 +6,10 @@ from ... import _bindings from ... import _cccl_interop as cccl -from ..._caching import cache_with_key +from ..._caching import cache_with_registered_key_functions from ..._cccl_interop import call_build, set_cccl_iterator_state from ..._utils.protocols import ( get_data_pointer, - get_dtype, validate_and_get_stream, ) from ..._utils.temp_storage_buffer import TempStorageBuffer @@ -131,42 +130,7 @@ def __call__( return temp_storage_bytes -def make_cache_key( - d_in_keys: DeviceArrayLike | DoubleBuffer, - d_out_keys: DeviceArrayLike | None, - d_in_values: DeviceArrayLike | DoubleBuffer | None, - d_out_values: DeviceArrayLike | None, - start_offsets_in: DeviceArrayLike, - end_offsets_in: DeviceArrayLike, - order: SortOrder, -): - d_in_keys_array, d_out_keys_array, d_in_values_array, d_out_values_array = ( - _get_arrays(d_in_keys, d_out_keys, d_in_values, d_out_values) - ) - - d_in_keys_key = get_dtype(d_in_keys_array) - d_out_keys_key = None if d_out_keys_array is None else get_dtype(d_out_keys_array) - d_in_values_key = ( - None if d_in_values_array is None else get_dtype(d_in_values_array) - ) - d_out_values_key = ( - None if d_out_values_array is None else get_dtype(d_out_values_array) - ) - start_offsets_in_key = get_dtype(start_offsets_in) - end_offsets_in_key = get_dtype(end_offsets_in) - - return ( - d_in_keys_key, - d_out_keys_key, - d_in_values_key, - d_out_values_key, - start_offsets_in_key, - end_offsets_in_key, - order, - ) - - -@cache_with_key(make_cache_key) +@cache_with_registered_key_functions def make_segmented_sort( d_in_keys: DeviceArrayLike | DoubleBuffer, d_out_keys: DeviceArrayLike | None, diff --git 
a/python/cuda_cccl/cuda/compute/algorithms/_sort/_sort_common.py b/python/cuda_cccl/cuda/compute/algorithms/_sort/_sort_common.py index 8c2caa505eb..047365f72d0 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_sort/_sort_common.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_sort/_sort_common.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Tuple +from ..._caching import cache_with_registered_key_functions +from ..._utils.protocols import get_dtype from ...typing import DeviceArrayLike @@ -50,3 +52,9 @@ def _get_arrays( d_out_values_array = d_out_values return d_in_keys_array, d_out_keys_array, d_in_values_array, d_out_values_array + + +# DoubleBuffer: extract dtype from current buffer +cache_with_registered_key_functions.register( + DoubleBuffer, lambda buf: get_dtype(buf.current()) +) diff --git a/python/cuda_cccl/cuda/compute/algorithms/_three_way_partition.py b/python/cuda_cccl/cuda/compute/algorithms/_three_way_partition.py index 9727d6f4ce5..9533765c100 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_three_way_partition.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_three_way_partition.py @@ -9,7 +9,7 @@ from .. import _bindings from .. 
import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import call_build, set_cccl_iterator_state from .._utils import protocols from .._utils.temp_storage_buffer import TempStorageBuffer @@ -18,50 +18,6 @@ from ..typing import DeviceArrayLike -def _make_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_first_part_out: DeviceArrayLike | IteratorBase, - d_second_part_out: DeviceArrayLike | IteratorBase, - d_unselected_out: DeviceArrayLike | IteratorBase, - d_num_selected_out: DeviceArrayLike | IteratorBase, - select_first_part_op: OpAdapter, - select_second_part_op: OpAdapter, -): - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - d_first_part_out_key = ( - d_first_part_out.kind - if isinstance(d_first_part_out, IteratorBase) - else protocols.get_dtype(d_first_part_out) - ) - d_second_part_out_key = ( - d_second_part_out.kind - if isinstance(d_second_part_out, IteratorBase) - else protocols.get_dtype(d_second_part_out) - ) - d_unselected_out_key = ( - d_unselected_out.kind - if isinstance(d_unselected_out, IteratorBase) - else protocols.get_dtype(d_unselected_out) - ) - d_num_selected_out_key = ( - d_num_selected_out.kind - if isinstance(d_num_selected_out, IteratorBase) - else protocols.get_dtype(d_num_selected_out) - ) - - return ( - d_in_key, - d_first_part_out_key, - d_second_part_out_key, - d_unselected_out_key, - d_num_selected_out_key, - select_first_part_op.get_cache_key(), - select_second_part_op.get_cache_key(), - ) - - class _ThreeWayPartition: __slots__ = [ "build_result", @@ -154,28 +110,7 @@ def __call__( return temp_storage_bytes -@cache_with_key(_make_cache_key) -def _make_three_way_partition_cached( - d_in: DeviceArrayLike | IteratorBase, - d_first_part_out: DeviceArrayLike | IteratorBase, - d_second_part_out: DeviceArrayLike | IteratorBase, - d_unselected_out: DeviceArrayLike | IteratorBase, - 
d_num_selected_out: DeviceArrayLike | IteratorBase, - select_first_part_op: OpAdapter, - select_second_part_op: OpAdapter, -): - """Internal cached factory for _ThreeWayPartition.""" - return _ThreeWayPartition( - d_in, - d_first_part_out, - d_second_part_out, - d_unselected_out, - d_num_selected_out, - select_first_part_op, - select_second_part_op, - ) - - +@cache_with_registered_key_functions def make_three_way_partition( d_in: DeviceArrayLike | IteratorBase, d_first_part_out: DeviceArrayLike | IteratorBase, @@ -212,7 +147,7 @@ def make_three_way_partition( first_op_adapter = make_op_adapter(select_first_part_op) second_op_adapter = make_op_adapter(select_second_part_op) - return _make_three_way_partition_cached( + return _ThreeWayPartition( d_in, d_first_part_out, d_second_part_out, diff --git a/python/cuda_cccl/cuda/compute/algorithms/_transform.py b/python/cuda_cccl/cuda/compute/algorithms/_transform.py index 82abe8a7fae..83843c00f12 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_transform.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_transform.py @@ -7,7 +7,7 @@ from .. import _bindings from .. 
import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import set_cccl_iterator_state from .._utils import protocols from ..iterators._iterators import IteratorBase @@ -119,59 +119,7 @@ def __call__( return None -def _make_unary_transform_cache_key( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, -): - d_in_key = ( - d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in) - ) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - return (d_in_key, d_out_key, op.get_cache_key()) - - -def _make_binary_transform_cache_key( - d_in1: DeviceArrayLike | IteratorBase, - d_in2: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, -): - d_in1_key = ( - d_in1.kind if isinstance(d_in1, IteratorBase) else protocols.get_dtype(d_in1) - ) - d_in2_key = ( - d_in2.kind if isinstance(d_in2, IteratorBase) else protocols.get_dtype(d_in2) - ) - d_out_key = ( - d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out) - ) - return (d_in1_key, d_in2_key, d_out_key, op.get_cache_key()) - - -@cache_with_key(_make_unary_transform_cache_key) -def _make_unary_transform_cached( - d_in: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, -): - """Internal cached factory for _UnaryTransform.""" - return _UnaryTransform(d_in, d_out, op) - - -@cache_with_key(_make_binary_transform_cache_key) -def _make_binary_transform_cached( - d_in1: DeviceArrayLike | IteratorBase, - d_in2: DeviceArrayLike | IteratorBase, - d_out: DeviceArrayLike | IteratorBase, - op: OpAdapter, -): - """Internal cached factory for _BinaryTransform.""" - return _BinaryTransform(d_in1, d_in2, d_out, op) - - +@cache_with_registered_key_functions def make_unary_transform( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike | 
IteratorBase, @@ -199,9 +147,10 @@ def make_unary_transform( A callable object that performs the transformation. """ op_adapter = make_op_adapter(op) - return _make_unary_transform_cached(d_in, d_out, op_adapter) + return _UnaryTransform(d_in, d_out, op_adapter) +@cache_with_registered_key_functions def make_binary_transform( d_in1: DeviceArrayLike | IteratorBase, d_in2: DeviceArrayLike | IteratorBase, @@ -231,7 +180,7 @@ def make_binary_transform( A callable object that performs the transformation. """ op_adapter = make_op_adapter(op) - return _make_binary_transform_cached(d_in1, d_in2, d_out, op_adapter) + return _BinaryTransform(d_in1, d_in2, d_out, op_adapter) def unary_transform( diff --git a/python/cuda_cccl/cuda/compute/algorithms/_unique_by_key.py b/python/cuda_cccl/cuda/compute/algorithms/_unique_by_key.py index 8f9c014feb0..c6966e0b99b 100644 --- a/python/cuda_cccl/cuda/compute/algorithms/_unique_by_key.py +++ b/python/cuda_cccl/cuda/compute/algorithms/_unique_by_key.py @@ -9,9 +9,8 @@ from .. import _bindings from .. 
import _cccl_interop as cccl -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import call_build, set_cccl_iterator_state -from .._utils import protocols from .._utils.protocols import ( get_data_pointer, validate_and_get_stream, @@ -22,46 +21,6 @@ from ..typing import DeviceArrayLike -def _make_cache_key( - d_in_keys: DeviceArrayLike | IteratorBase, - d_in_items: DeviceArrayLike | IteratorBase, - d_out_keys: DeviceArrayLike | IteratorBase, - d_out_items: DeviceArrayLike | IteratorBase, - d_out_num_selected: DeviceArrayLike, - op: OpAdapter, -): - d_in_keys_key = ( - d_in_keys.kind - if isinstance(d_in_keys, IteratorBase) - else protocols.get_dtype(d_in_keys) - ) - d_in_items_key = ( - d_in_items.kind - if isinstance(d_in_items, IteratorBase) - else protocols.get_dtype(d_in_items) - ) - d_out_keys_key = ( - d_out_keys.kind - if isinstance(d_out_keys, IteratorBase) - else protocols.get_dtype(d_out_keys) - ) - d_out_items_key = ( - d_out_items.kind - if isinstance(d_out_items, IteratorBase) - else protocols.get_dtype(d_out_items) - ) - d_out_num_selected_key = protocols.get_dtype(d_out_num_selected) - - return ( - d_in_keys_key, - d_in_items_key, - d_out_keys_key, - d_out_items_key, - d_out_num_selected_key, - op.get_cache_key(), - ) - - class _UniqueByKey: __slots__ = [ "build_result", @@ -145,21 +104,7 @@ def __call__( return temp_storage_bytes -@cache_with_key(_make_cache_key) -def _make_unique_by_key_cached( - d_in_keys: DeviceArrayLike | IteratorBase, - d_in_items: DeviceArrayLike | IteratorBase, - d_out_keys: DeviceArrayLike | IteratorBase, - d_out_items: DeviceArrayLike | IteratorBase, - d_out_num_selected: DeviceArrayLike, - op: OpAdapter, -): - """Internal cached factory for _UniqueByKey.""" - return _UniqueByKey( - d_in_keys, d_in_items, d_out_keys, d_out_items, d_out_num_selected, op - ) - - +@cache_with_registered_key_functions def make_unique_by_key( d_in_keys: DeviceArrayLike | 
IteratorBase, d_in_items: DeviceArrayLike | IteratorBase, @@ -190,13 +135,8 @@ def make_unique_by_key( A callable object that can be used to perform unique by key """ op_adapter = make_op_adapter(op) - return _make_unique_by_key_cached( - d_in_keys, - d_in_items, - d_out_keys, - d_out_items, - d_out_num_selected, - op_adapter, + return _UniqueByKey( + d_in_keys, d_in_items, d_out_keys, d_out_items, d_out_num_selected, op_adapter ) diff --git a/python/cuda_cccl/cuda/compute/iterators/_iterators.py b/python/cuda_cccl/cuda/compute/iterators/_iterators.py index 9a88255a8c9..96c220c0cc9 100644 --- a/python/cuda_cccl/cuda/compute/iterators/_iterators.py +++ b/python/cuda_cccl/cuda/compute/iterators/_iterators.py @@ -13,7 +13,7 @@ from numba.cuda.dispatcher import CUDADispatcher from .._bindings import IteratorState -from .._caching import CachableFunction +from .._caching import CachableFunction, cache_with_registered_key_functions from .._utils.protocols import ( compute_c_contiguous_strides_in_bytes, get_data_pointer, @@ -678,3 +678,6 @@ def _get_last_element_ptr(device_array) -> int: ptr = get_data_pointer(device_array) return ptr + offset_to_last_element + + +cache_with_registered_key_functions.register(IteratorBase, lambda it: it.kind) diff --git a/python/cuda_cccl/cuda/compute/iterators/_permutation_iterator.py b/python/cuda_cccl/cuda/compute/iterators/_permutation_iterator.py index 74479e28bf1..8da925cd9d3 100644 --- a/python/cuda_cccl/cuda/compute/iterators/_permutation_iterator.py +++ b/python/cuda_cccl/cuda/compute/iterators/_permutation_iterator.py @@ -10,7 +10,7 @@ from numba.core.datamodel.registry import default_manager # noqa: F401 from numba.core.extending import as_numba_type, intrinsic # noqa: F401 -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from .._cccl_interop import get_dtype from ..struct import make_struct_type from ._iterators import ( @@ -24,12 +24,7 @@ class 
PermutationIteratorKind(IteratorKind): pass -def _make_cache_key(values, indices): - """Create a cache key based on value type and iterator kinds.""" - return (values.value_type, values.kind, indices.kind) - - -@cache_with_key(_make_cache_key) +@cache_with_registered_key_functions def _generate_advance_and_dereference_methods(values, indices): values_state_type = values.state_type index_type = indices.value_type diff --git a/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py b/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py index 0d8807e7333..d6ff73ebefb 100644 --- a/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py +++ b/python/cuda_cccl/cuda/compute/iterators/_zip_iterator.py @@ -9,7 +9,7 @@ from numba.core.datamodel.registry import default_manager # noqa: F401 from numba.core.extending import as_numba_type, intrinsic # noqa: F401 -from .._caching import cache_with_key +from .._caching import cache_with_registered_key_functions from ..struct import make_struct_type from ._iterators import ( IteratorBase, @@ -50,13 +50,8 @@ def _get_zip_iterator_metadata(iterators): return cvalue, state_type, value_type, ZipValue -def _make_cache_key(iterators): - return tuple( - it.kind if isinstance(it, IteratorBase) else it.dtype for it in iterators - ) - - -@cache_with_key(_make_cache_key) +# Automatic: iterators tuple → (kind, kind, ...) or (dtype, dtype, ...) 
+@cache_with_registered_key_functions def _get_advance_and_dereference_functions(iterators): # Generate the advance and dereference functions for the zip iterator # composed of the input iterators diff --git a/python/cuda_cccl/cuda/compute/op.py b/python/cuda_cccl/cuda/compute/op.py index ca534abf4a7..5a8e8f4b5bb 100644 --- a/python/cuda_cccl/cuda/compute/op.py +++ b/python/cuda_cccl/cuda/compute/op.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Callable, Hashable +from typing import Callable from ._bindings import Op, OpKind -from ._caching import CachableFunction +from ._caching import CachableFunction, cache_with_registered_key_functions def _is_well_known_op(op: OpKind) -> bool: @@ -19,10 +19,6 @@ class _OpAdapter: - Stateless user-provided callables """ - def get_cache_key(self) -> Hashable: - """Return a hashable cache key for this operator.""" - raise NotImplementedError("Subclasses must implement this method") - def compile(self, input_types, output_type=None) -> Op: """ Compile this operator to an Op for CCCL interop. @@ -55,9 +51,6 @@ def __init__(self, kind: OpKind): ) self._kind = kind - def get_cache_key(self) -> Hashable: - return (self._kind.name, self._kind.value) - def compile(self, input_types, output_type=None) -> Op: return Op( operator_type=self._kind, @@ -82,9 +75,6 @@ def __init__(self, func: Callable): self._func = func self._cachable = CachableFunction(func) - def get_cache_key(self) -> Hashable: - return self._cachable - def compile(self, input_types, output_type=None) -> Op: from . 
import _cccl_interop as cccl from .numba_utils import get_inferred_return_type, signature_from_annotations @@ -138,3 +128,22 @@ def make_op_adapter(op) -> OpAdapter: "OpKind", "make_op_adapter", ] + + +# ============================================================================ +# Register key functions +# ============================================================================ + +cache_with_registered_key_functions.register( + _WellKnownOp, lambda op: (op._kind.name, op._kind.value) +) + +cache_with_registered_key_functions.register(_StatelessOp, lambda op: op._cachable) + +cache_with_registered_key_functions.register( + OpKind, lambda kind: (kind.name, kind.value) +) + +cache_with_registered_key_functions.register( + type(lambda: None), lambda func: CachableFunction(func) +) diff --git a/python/cuda_cccl/cuda/compute/typing.py b/python/cuda_cccl/cuda/compute/typing.py index 897b530e5b5..3ba7cba93cd 100644 --- a/python/cuda_cccl/cuda/compute/typing.py +++ b/python/cuda_cccl/cuda/compute/typing.py @@ -26,6 +26,7 @@ class StreamLike(Protocol): def __cuda_stream__(self) -> tuple[int, int]: ... +@runtime_checkable class GpuStruct(Protocol): """ Type of instances of structs created with gpu_struct().