cuda.compute: Consolidate caching logic across all algorithms#7281
shwina merged 2 commits into NVIDIA:main
Conversation
python/cuda_cccl/cuda/compute/op.py (outdated)
```python
def get_cache_key(self) -> Hashable:
    """Return a hashable cache key for this operator."""
    raise NotImplementedError("Subclasses must implement this method")
```

Note: Cache key extraction is handled by the cache key registry,
Just delete this method.
```python
"""
Decorator to cache the result of the decorated function.

If `key` is provided, it should be a function that computes the cache key
```
```diff
- If `key` is provided, it should be a function that computes the cache key
+ If `func_or_key_function` is provided, it should be a function that computes the cache key
```
We take no arguments in the decorator now, so this has gone away.
```python
Decorator to cache the result of the decorated function.

If `key` is provided, it should be a function that computes the cache key
from the function arguments. Otherwise, the cache key is automatically
```
```diff
- from the function arguments. Otherwise, the cache key is automatically
+ from the decorated function's arguments. Otherwise, the cache key is automatically
```
Done (clarified whose arguments in the docstring).
```python
Args:
    func_or_key_function: The function to decorate, or a custom key function.
"""
# Check if this is being used without parentheses (@cache_with_registered_key_functions)
```
Suggestion: if func_or_key_function is always supposed to be a callable, can we just add a type check at the start and fail fast if the passed value is not a callable? Then remove the checks for callable later
Since we don't take func_or_key_function as an argument anymore (just func), this is a non-issue
```python
    """
    _KEY_FUNCTIONS[type_] = key_function


def __call__(self, func_or_key_function=None):
```
Suggestion: add a type annotation to func_or_key_function
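One way this could look, sketched under the assumption that `__call__` may receive either the decorated function or `None`; the class name and type alias below are illustrative, not the PR's actual code:

```python
from typing import Any, Callable, Hashable, Optional

# Hypothetical alias for a custom key function (not in the PR).
KeyFunction = Callable[..., Hashable]


class CacheDecorator:
    # Illustrative stand-in for the class containing __call__ in the diff.
    def __call__(
        self, func_or_key_function: Optional[Callable[..., Any]] = None
    ) -> Callable[..., Any]:
        # Fail fast on non-callable arguments, per the review discussion.
        if func_or_key_function is not None and not callable(func_or_key_function):
            raise TypeError(
                f"expected a callable, got {type(func_or_key_function).__name__}"
            )
        return func_or_key_function
```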
```python
if isinstance(value, np.dtype):
    return value
```
Suggestion: let us register a key function for np.dtype (if we don't already)
```python
if hasattr(value, "__cuda_array_interface__"):
    # It's a device array-like, extract dtype and fully-qualified type name
    from ._utils.protocols import get_dtype

    # Use fully-qualified type name (e.g., 'numpy.ndarray' v/s 'cupy.ndarray')
    type_fqn = f"{type(value).__module__}.{type(value).__name__}"
    return (type_fqn, get_dtype(value))

# Check if it has a dtype attribute
if hasattr(value, "dtype"):
    type_fqn = f"{type(value).__module__}.{type(value).__name__}"
    return (type_fqn, value.dtype)

# For callables, wrap in CachableFunction
if callable(value):
    return CachableFunction(value)
```

```python
def cache_clear():
    cache.clear()


inner.cache_clear = cache_clear

# Register the cache in the central registry
cache_name = func.__qualname__
_cache_registry[cache_name] = inner
```
Suggestion: register key functions instead of implementing logic here
```python
# Register built-in types
# Include fully-qualified type name to distinguish np.ndarray from cp.ndarray from GpuStruct
_KEY_FUNCTIONS[np.ndarray] = lambda arr: ("numpy.ndarray", arr.dtype)
```
Suggestion: use register() instead if possible
```python
if callable(func_or_key_function) and not hasattr(
    func_or_key_function, "__self__"
):
    # Direct decoration: @cache_with_registered_key_functions
    # func_or_key is the actual function being decorated
    return self._make_wrapper(None)(func_or_key_function)
```
Suggestion: let's not accept custom key functions for now since none of the existing algorithms need it
```python
# Keep old name for backwards compatibility
cache_with_key = cache_with_registered_key_functions
```
Suggestion: remove this
🥳 CI Workflow Results: 🟩 Finished in 13h 38m. Pass: 100%/48 | Total: 15h 19m | Max: 59m 28s. See results here.
Description
This PR is purely a refactor.
In #6938 (comment), I promised to simplify how we do caching. This PR consolidates all of the caching logic by introducing a registry that maps types to key functions (defined in `_caching.py`). Thus, we don't have to write a cache key function for every individual algorithm (which was incredibly repetitive). As a nice side effect, this also removes the indirection from the referenced comment.
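The registry approach described above can be sketched roughly as follows; the names (`register`, `_extract_key`) and the hashable-fallback behavior are illustrative and may not match the PR's actual `_caching.py`:

```python
import functools

# Illustrative type -> key-function registry plus a caching decorator
# that derives cache keys from it.
_KEY_FUNCTIONS = {}


def register(type_):
    def decorator(key_function):
        _KEY_FUNCTIONS[type_] = key_function
        return key_function

    return decorator


def _extract_key(value):
    key_function = _KEY_FUNCTIONS.get(type(value))
    if key_function is not None:
        return key_function(value)
    return value  # fall back to the value itself if it is hashable


def cache_with_registered_key_functions(func):
    cache = {}

    @functools.wraps(func)
    def inner(*args):
        key = tuple(_extract_key(a) for a in args)
        if key not in cache:
            cache[key] = func(*args)
        return cache[key]

    inner.cache_clear = cache.clear
    return inner
```

With this shape, an individual algorithm only registers key functions for any new argument types it introduces; the caching decorator itself is shared.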
Checklist