
cuda.compute: Consolidate caching logic across all algorithms#7281

Merged
shwina merged 2 commits into NVIDIA:main from shwina:consolidate-caching
Jan 21, 2026

Conversation

@shwina
Contributor

@shwina shwina commented Jan 18, 2026

Description

This PR is purely a refactor.

In #6938 (comment), I promised to simplify how we do caching. This PR consolidates all of the caching logic by introducing a registry that maps types to key functions (defined in _caching.py). As a result, we no longer have to write a cache key function for every individual algorithm, which was highly repetitive.

As a nice side-effect, this also removes the indirection from the referenced comment.
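Sketched roughly, the registry approach can look like this (a minimal illustration with hypothetical names — `register`, `extract_key` — not the actual `_caching.py` API):

```python
from typing import Any, Callable, Hashable

# Hypothetical registry mapping a type to a function that extracts a
# hashable cache key from values of that type.
_KEY_FUNCTIONS: dict[type, Callable[[Any], Hashable]] = {}


def register(type_: type, key_function: Callable[[Any], Hashable]) -> None:
    """Register a key function for all values of the given type."""
    _KEY_FUNCTIONS[type_] = key_function


def extract_key(value: Any) -> Hashable:
    """Compute a cache key via the first matching registered type."""
    for type_, key_function in _KEY_FUNCTIONS.items():
        if isinstance(value, type_):
            return key_function(value)
    # Fall back to the value itself when it is already hashable.
    return value


# One registration now serves every algorithm that accepts this type,
# instead of each algorithm carrying its own key logic.
register(int, lambda v: ("int", v))
```

With this shape, supporting a new argument type is a single registration call rather than a change to every algorithm's key function.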

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@shwina shwina requested a review from a team as a code owner January 18, 2026 11:45
@shwina shwina requested a review from NaderAlAwar January 18, 2026 11:45
@github-project-automation github-project-automation bot moved this to Todo in CCCL Jan 18, 2026
@cccl-authenticator-app cccl-authenticator-app bot moved this from Todo to In Review in CCCL Jan 18, 2026
@shwina shwina force-pushed the consolidate-caching branch from 836bf47 to 62af0fe on January 18, 2026 11:46

def get_cache_key(self) -> Hashable:
    """Return a hashable cache key for this operator."""
    raise NotImplementedError("Subclasses must implement this method")
Note: Cache key extraction is handled by the cache key registry,
Contributor Author

Just delete this method.

@shwina shwina force-pushed the consolidate-caching branch from 62af0fe to 2206921 on January 19, 2026 15:25

@shwina shwina force-pushed the consolidate-caching branch from 2206921 to e97de0b on January 20, 2026 10:29

@shwina shwina self-assigned this Jan 20, 2026
"""
Decorator to cache the result of the decorated function.

If `key` is provided, it should be a function that computes the cache key
Contributor

Suggested change:
- If `key` is provided, it should be a function that computes the cache key
+ If `func_or_key_function` is provided, it should be a function that computes the cache key

Contributor Author

We take no arguments in the decorator now, so this has gone away.

Decorator to cache the result of the decorated function.

If `key` is provided, it should be a function that computes the cache key
from the function arguments. Otherwise, the cache key is automatically
Contributor

Suggested change:
- from the function arguments. Otherwise, the cache key is automatically
+ from the decorated function's arguments. Otherwise, the cache key is automatically

Contributor Author

Done (clarified whose arguments in the docstring).

Args:
    func_or_key_function: The function to decorate, or a custom key function.
"""
# Check if this is being used without parentheses (@cache_with_registered_key_functions)
Contributor

Suggestion: if func_or_key_function is always supposed to be a callable, can we just add a type check at the start and fail fast if the passed value is not a callable? Then remove the checks for callable later

Contributor Author

Since we don't take func_or_key_function as an argument anymore (just func), this is a non-issue

"""
_KEY_FUNCTIONS[type_] = key_function

def __call__(self, func_or_key_function=None):
Contributor

Suggestion: add a type annotation to func_or_key_function

Contributor Author

Ditto as above.

Comment on lines 51 to 52
if isinstance(value, np.dtype):
    return value
Contributor

Suggestion: let us register a key function for np.dtype (if we don't already)

Contributor Author

Done.

Comment on lines 59 to 74
if hasattr(value, "__cuda_array_interface__"):
    # It's a device array-like, extract dtype and fully-qualified type name
    from ._utils.protocols import get_dtype

    # Use fully-qualified type name (e.g., 'numpy.ndarray' v/s 'cupy.ndarray')
    type_fqn = f"{type(value).__module__}.{type(value).__name__}"
    return (type_fqn, get_dtype(value))

# Check if it has a dtype attribute
if hasattr(value, "dtype"):
    type_fqn = f"{type(value).__module__}.{type(value).__name__}"
    return (type_fqn, value.dtype)

# For callables, wrap in CachableFunction
if callable(value):
    return CachableFunction(value)

def cache_clear():
    cache.clear()

inner.cache_clear = cache_clear

# Register the cache in the central registry
cache_name = func.__qualname__
_cache_registry[cache_name] = inner
Contributor

Suggestion: register key functions instead of implementing logic here

Contributor Author

Done.
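Moving that inline logic into registrations might look roughly like this (hypothetical names; a device array type is shown only in a comment since cupy is not importable everywhere):

```python
import numpy as np

_KEY_FUNCTIONS = {}  # stand-in for the registry in _caching.py


def array_key(value):
    # The fully-qualified type name distinguishes numpy.ndarray from
    # cupy.ndarray, which would otherwise collide on dtype alone.
    type_fqn = f"{type(value).__module__}.{type(value).__name__}"
    return (type_fqn, value.dtype)


_KEY_FUNCTIONS[np.ndarray] = array_key
# A device array type (e.g. cupy.ndarray) would be registered the same
# way, using __cuda_array_interface__ to extract the dtype.
```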


# Register built-in types
# Include fully-qualified type name to distinguish np.ndarray from cp.ndarray from GpuStruct
_KEY_FUNCTIONS[np.ndarray] = lambda arr: ("numpy.ndarray", arr.dtype)
Contributor

Suggestion: use register() instead if possible

Contributor Author

Done.

Comment on lines 154 to 159
if callable(func_or_key_function) and not hasattr(
    func_or_key_function, "__self__"
):
    # Direct decoration: @cache_with_registered_key_functions
    # func_or_key is the actual function being decorated
    return self._make_wrapper(None)(func_or_key_function)
Contributor

Suggestion: let's not accept custom key functions for now since none of the existing algorithms need it

Contributor Author

Done.
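With custom key functions dropped, the decorator reduces to the bare no-argument form, roughly like this (a self-contained sketch, not the actual implementation):

```python
import functools


def cache_plain(func):
    """Bare caching decorator: takes only the function, no options."""
    cache = {}

    @functools.wraps(func)
    def inner(*args):
        # The real decorator would map args through the registered key
        # functions; plain positional args suffice for this sketch.
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    inner.cache_clear = cache.clear
    return inner


calls = []


@cache_plain
def double(x):
    calls.append(x)
    return 2 * x


double(3)
double(3)  # second call is served from the cache
```

Because the decorator accepts exactly one callable, the earlier ambiguity between "function being decorated" and "custom key function" disappears.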

Comment on lines 210 to 211
# Keep old name for backwards compatibility
cache_with_key = cache_with_registered_key_functions
Contributor

Suggestion: remove this

Contributor Author

Done.

@shwina shwina force-pushed the consolidate-caching branch from e97de0b to 62b81d4 on January 20, 2026 21:56

@github-actions
Contributor

🥳 CI Workflow Results

🟩 Finished in 13h 38m: Pass: 100%/48 | Total: 15h 19m | Max: 59m 28s

See results here.

@shwina shwina requested a review from NaderAlAwar January 21, 2026 13:57
Contributor

@NaderAlAwar NaderAlAwar left a comment


Great work!

@shwina shwina merged commit e9f0971 into NVIDIA:main Jan 21, 2026
123 of 126 checks passed
@github-project-automation github-project-automation bot moved this from In Review to Done in CCCL Jan 21, 2026
