Unify operator handling in cuda.compute#6938

Merged: shwina merged 2 commits into main from op-adapter-refactor on Dec 10, 2025
Conversation

@shwina (Contributor) commented Dec 10, 2025

Description

This PR is a refactor, in preparation for supporting stateful ops in cuda.compute.

The problem(s)

Duplicated logic for handling OpKind vs. callables

Currently, user-defined operations (e.g., predicates, transformations, etc.) can be provided either as "built-in" ops (OpKind) or custom functions (callables). These differ in the way they are (1) cached and (2) "compiled" into LTOIR. To handle these differences, we currently have a bunch of if-else statements everywhere, like:

    if isinstance(op, OpKind):
        self.op_wrapper = cccl.to_cccl_op(op, None)
    else:
        self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))

and:

    # Handle well-known operations differently
    op_key: Union[tuple[str, int], CachableFunction]
    if isinstance(op, OpKind):
        op_key = (op.name, op.value)
    else:
        op_key = CachableFunction(op)

Determining signatures from annotations

If provided, type annotations offer a faster way to determine the return type of a user-defined callable than numba type inference. We take advantage of this in, e.g., transform, but it would be nice to do this for all ops. Ideally, we don't want to repeat the logic everywhere.
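As a hedged illustration (not the actual cuda.compute code; the helper name is hypothetical), the annotation fast path amounts to reading the callable's return annotation first and falling back to type inference only when it is absent:

```python
from typing import Optional, get_type_hints


def return_type_from_annotations(func) -> Optional[type]:
    """Return the annotated return type of ``func``, or None when no
    annotation is present (the caller would then fall back to
    numba-style type inference). Illustrative helper only."""
    hints = get_type_hints(func)
    return hints.get("return")


def annotated_add(a: int, b: int) -> int:
    return a + b


def unannotated_add(a, b):
    return a + b
```

With a helper like this, only callables lacking a return annotation pay the cost of full type inference.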

Solution

This PR solves the above by introducing an (internal) OpAdapter type that encapsulates the logic for caching, signature determination, and compiling. Furthermore, it will make adding support for stateful ops much easier.
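A minimal sketch of how such an adapter can fold the isinstance branches into two small classes behind one interface (names and details here are illustrative assumptions, not the PR's actual implementation):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Hashable, Union


class OpKind(Enum):
    """Stand-in for the real built-in op enum."""
    PLUS = 0
    MINIMUM = 1


@dataclass(frozen=True)
class _BuiltinOpAdapter:
    kind: OpKind

    def cache_key(self) -> Hashable:
        return (self.kind.name, self.kind.value)


@dataclass(frozen=True)
class _CallableOpAdapter:
    func: Callable

    def cache_key(self) -> Hashable:
        # The real code wraps user functions in CachableFunction here.
        return self.func


def make_op_adapter(op: Union[OpKind, Callable]):
    """Normalize a built-in OpKind or a user callable into one adapter,
    so call sites never branch on isinstance(op, OpKind) themselves."""
    if isinstance(op, OpKind):
        return _BuiltinOpAdapter(op)
    return _CallableOpAdapter(op)
```

Each call site then asks the adapter for its cache key (and, by extension, its signature and compiled form) instead of duplicating the dispatch.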

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@copy-pr-bot commented Dec 10, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@cccl-authenticator-app cccl-authenticator-app bot moved this from Todo to In Progress in CCCL Dec 10, 2025
@shwina (Contributor, Author) commented Dec 10, 2025

/ok to test 2d3f647



-@cache_with_key(make_cache_key)
+@cache_with_key(_make_cache_key)
 def _make_merge_sort_cached(
shwina (Contributor, Author) commented:
Note to reviewers: this approach does introduce a layer of indirection here in the caching.

I do have some ideas for unifying caching across all algorithms which should simplify this, but I'll do that in a subsequent PR.

@shwina shwina marked this pull request as ready for review December 10, 2025 20:11
@shwina shwina requested a review from a team as a code owner December 10, 2025 20:11
@shwina shwina requested a review from leofang December 10, 2025 20:11
@cccl-authenticator-app cccl-authenticator-app bot moved this from In Progress to In Review in CCCL Dec 10, 2025
@NaderAlAwar (Contributor) left a comment:
Looks good, left a few comments. We should also be careful not to introduce too much overhead to the single-phase API.

        self._kind = kind

    def get_cache_key(self) -> Hashable:
        return (self.__class__.__name__, self._kind.name, self._kind.value)
NaderAlAwar (Contributor) commented:
Question: prior to this change, we only returned (op.name, op.value). Why do we need to include self.__class__.__name__?

shwina (Contributor, Author) replied:
Fixed - ditto below

        self._cachable = CachableFunction(func)

    def get_cache_key(self) -> Hashable:
        return (self.__class__.__name__, self._cachable)
NaderAlAwar (Contributor) commented:
Same question as above, why do we need self.__class__.__name__?

self.d_in_items_cccl = cccl.to_cccl_input_iter(d_in_items)
self.d_out_keys_cccl = cccl.to_cccl_output_iter(d_out_keys)
self.d_out_items_cccl = cccl.to_cccl_output_iter(d_out_items)
self.op_adapter = op
NaderAlAwar (Contributor) commented:
Question: why do we store op_adapter as a member variable? This also applies to other algorithms

shwina (Contributor, Author) replied:
When we do introduce stateful operators, the op_adapter is what will hold the state arrays. That being said, let me remove this change from this PR and introduce it (or something else) in the subsequent PR.

self.d_in_cccl = cccl.to_cccl_input_iter(d_in)
self.d_out_cccl = cccl.to_cccl_output_iter(d_out)
self.h_init_cccl = cccl.to_cccl_value(h_init)
self.op = op
NaderAlAwar (Contributor) commented:
Important: this is named op_adapter in merge_sort. We should use consistent names.

shwina (Contributor, Author) replied:
I'll use op everywhere.

 d_out: DeviceArrayLike | IteratorBase,
 d_num_selected_out: DeviceArrayLike,
-cond: Callable,
+cond: Callable | OpAdapter,  # Raw callable or Operator
NaderAlAwar (Contributor) commented:
Question: it is not clear to me why this is annotated differently than the other algorithms? Also the comment seems unnecessary

shwina (Contributor, Author) replied:
Fixed


@shwina (Contributor, Author) commented Dec 10, 2025

/ok to test 557b6a0

@shwina shwina enabled auto-merge (squash) December 10, 2025 21:49
@github-actions commented:
🥳 CI Workflow Results

🟩 Finished in 1h 28m: Pass: 100%/48 | Total: 11h 31m | Max: 44m 25s


@shwina shwina merged commit 01fa22a into main Dec 10, 2025
62 checks passed
@github-project-automation github-project-automation bot moved this from In Review to Done in CCCL Dec 10, 2025
@shwina shwina deleted the op-adapter-refactor branch December 11, 2025 09:57