Enable multi-partition Select operations containing basic aggregations #17941

Closed
rjzamora wants to merge 70 commits into rapidsai:branch-25.06 from rjzamora:complex-aggregations

Conversation

@rjzamora
Member

@rjzamora rjzamora commented Feb 6, 2025

Description

The overall goal is to enable us to decompose arbitrary Expr graphs containing one or more "non-pointwise" nodes. In order to achieve this, I propose that we add an experimental FusedExpr class (and related Expr-graph decomposition utilities). The general idea is that we can iteratively traverse an Expr-graph in reverse-topological order, and rewrite the graph until it is entirely composed of FusedExpr nodes. From there, it becomes relatively simple to build the task graph for each FusedExpr node independently.
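For orientation, here is a minimal structural sketch of what a FusedExpr node might carry, inferred from the constructor calls and review discussion below (illustrative only, not the actual implementation):

class FusedExpr(Expr):
    """An Expr wrapping a fused sub-expression.

    The sub-expression contains at most one non-pointwise node (plus any
    surrounding pointwise work); ``fused_children`` are the other FusedExpr
    nodes this one depends on, so the two kinds of "children" are distinct.
    """

    def __init__(self, dtype, sub_expr: Expr, *fused_children: "FusedExpr") -> None:
        self.dtype = dtype
        self.sub_expr = sub_expr
        self.children = fused_children

Once the graph is fully rewritten, the task graph for each FusedExpr can be built independently, wiring the outputs of its fused children in as inputs.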

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@rjzamora rjzamora added feature request New feature or request 2 - In Progress Currently a work in progress non-breaking Non-breaking change cudf-polars Issues specific to cudf-polars labels Feb 6, 2025
@rjzamora rjzamora self-assigned this Feb 6, 2025
@copy-pr-bot

copy-pr-bot bot commented Feb 6, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@github-actions github-actions bot added the Python Affects Python cuDF API. label Feb 6, 2025
@rjzamora
Member Author

rjzamora commented Feb 6, 2025

Note: The rewrite approach used in this PR is illustrated in the figure below.

[Figure: "Aggregations" drawio diagram illustrating the Expr-graph rewrite approach]

Contributor

@wence- wence- left a comment


Partial comments

Comment on lines 145 to 150
def rename_agg(agg: Agg, new_name: str):
    """Modify the name of an aggregation expression."""
    return CachingVisitor(
        replace_sub_expr,
        state={"mapping": {agg: Agg(agg.dtype, new_name, agg.options, *agg.children)}},
    )(agg)
Contributor


"Renaming" feels like the wrong thing, because the options for one agg might not apply to the options of another.

Member Author


the options for one agg might not apply to the options of another

Yeah, that's totally true. We definitely need to figure out how the options "plumbing" should work here. I was hoping the options would normally translate in a trivial way, but I don't have a great sense for the range of possibilities.

Contributor


I think we just need pattern rules that map a global Agg into a pair of (local_aggs, finalise_agg) or whatever it looks like.
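For example, such a pattern table might look roughly like this (purely illustrative; the names below, including mean_finalize, are hypothetical placeholders rather than actual cudf-polars API):

# Map each global aggregation to the per-partition ("local") aggregations
# and the aggregation that combines/finalises the partial results.
DECOMPOSED_AGGS: dict[str, tuple[tuple[str, ...], str]] = {
    "sum": (("sum",), "sum"),      # sum the partial sums
    "count": (("count",), "sum"),  # sum the partial counts
    "min": (("min",), "min"),
    "max": (("max",), "max"),
    # mean needs two locals and a finalising step:
    # mean = sum(partial sums) / sum(partial counts)
    "mean": (("sum", "count"), "mean_finalize"),
}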

Member Author


Yeah, for this iteration we can just pass in None for the "new" options by default (since that's actually the "correct" thing to do for the supported aggregations that require "renaming").
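For illustration, a variant of rename_agg along those lines (a sketch only; the real options plumbing may differ):

def rename_agg(agg: Agg, new_name: str) -> Agg:
    """Rename an aggregation, resetting its options to None."""
    # None is the "correct" default for the supported aggregations
    # that require renaming (see discussion above).
    replacement = Agg(agg.dtype, new_name, None, *agg.children)
    return CachingVisitor(
        replace_sub_expr,
        state={"mapping": {agg: replacement}},
    )(agg)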

rapids-bot bot pushed a commit that referenced this pull request Feb 13, 2025
It will be useful to serialize individual columns during multi-GPU cudf-polars execution. For example, the `Expr`-decomposition approach proposed in #17941 may "require" `Column` serialization (or an ugly workaround).

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Matthew Murray (https://github.com/Matt711)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17990
children = [child for child in traversal([old]) if isinstance(child, FusedExpr)]
new = FusedExpr(old.dtype, old, *children)
mapper = CachingVisitor(replace_sub_expr, state={"mapping": {old: new}})
root = mapper(root)
Contributor


It seems like it should be possible to write this function as a single bottom up visit rather than this approach which does a fixed-point iteration with three full tree-traversals per iteration.

Can you describe what rewrite patterns at a node are?

I think it's something like:

def _decompose(expr: Expr, rec: ExprTransformer):
    new_children = tuple(map(rec, expr.children))
    fused_children = tuple(c for c in new_children if isinstance(c, FusedExpr))
    if fused_children:
        return FusedExpr(expr.dtype, expr, *fused_children)
    elif not expr.is_pointwise:
        # check for supported case of `Agg`...
        return FusedExpr(expr.dtype, expr)
    else:
        # pointwise, no fused children
        return expr

def decompose_expr_graph(expr):
    mapper = CachingVisitor(_decompose)
    return mapper(expr)

Do I have it right?

Member Author


Yeah, kind of. I played with this for a bit, and I think I have something working now.

@rjzamora
Member Author

rjzamora commented Apr 8, 2025

Update: This PR now depends on #18405

Contributor

@wence- wence- left a comment


Lots of questions/comments. I think this is pretty good, but there are a few places where I'm confused about the logic.

def collect_agg(self, *, depth: int) -> AggInfo:  # pragma: no cover
    """Collect information about aggregations in groupbys."""
    return self.sub_expr.collect_agg(depth=depth)
assert all(
Contributor


These should be independent of the size of the graph, do you mean "large expressions"?

Comment on lines 116 to 117
child_ir_count
    Partition count for the child-IR node.
Contributor


Do you mean "child" here, or do you mean the partition count of the input frame?

Member Author


I use the term child-IR to mean the input dataframe in most places, since this IR node is the child of the overarching Select node (that owns the expression we are processing). It is indeed not an expression child.

Would you prefer "input IR"? I just want to distinguish this input IR node from the IR node that "owns" the expression. Otherwise, I don't care much about the naming.

Comment on lines 135 to 136
if skip_fused_exprs:
    continue  # Stay within the current sub expression
Contributor


If you add:

diff --git a/python/cudf_polars/cudf_polars/dsl/traversal.py b/python/cudf_polars/cudf_polars/dsl/traversal.py
index 9c45a68812..50a643e2cc 100644
--- a/python/cudf_polars/cudf_polars/dsl/traversal.py
+++ b/python/cudf_polars/cudf_polars/dsl/traversal.py
@@ -49,6 +49,39 @@ def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
                 lifo.append(child)
 
 
+def cutoff_traversal(
+    nodes: Sequence[NodeT], *, cutoff_types: tuple[type[NodeT], ...]
+) -> Generator[NodeT, None, None]:
+    """
+    Pre-order traversal of nodes in an expression.
+
+    Parameters
+    ----------
+    nodes
+        Roots of expressions to traverse.
+    cutoff_types
+        Types to terminate traversal at. If a type is in this tuple
+        then we do not yield any of its children.
+
+    Yields
+    ------
+    Unique nodes in the expressions, parent before child, children
+    in-order from left to right.
+    """
+    seen = set(nodes)
+    lifo = list(nodes)
+
+    while lifo:
+        node = lifo.pop()
+        yield node
+        if isinstance(node, cutoff_types):
+            continue
+        for child in reversed(node.children):
+            if child not in seen:
+                seen.add(child)
+                lifo.append(child)
+
+
 def reuse_if_unchanged(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> NodeT:
     """
     Recipe for transforming nodes that returns the old object if unchanged.

and then use (at the top of this function):

traverse = (
    partial(cutoff_traversal, cutoff_types=(FusedExpr,))
    if skip_fused_exprs
    else traversal
)
...
for e in exprs:
    for node in list(traverse([e]))[::-1]:
        ...

You won't needlessly traverse the full tree when skipping fused exprs in the children.

Member Author


Yeah, I was tempted to do this before, but held off. Shall I just add an optional cutoff_types argument to traversal rather than duplicate most of the logic?

Contributor


Yeah, I think that's reasonable.
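That is, something like the following merged version (a sketch, assuming the module's existing imports; the default of () preserves the current behaviour, and this is not necessarily what landed):

def traversal(
    nodes: Sequence[NodeT], *, cutoff_types: tuple[type[NodeT], ...] = ()
) -> Generator[NodeT, None, None]:
    """Pre-order traversal, optionally terminating at cutoff_types nodes."""
    seen = set(nodes)
    lifo = list(nodes)
    while lifo:
        node = lifo.pop()
        yield node
        if cutoff_types and isinstance(node, cutoff_types):
            continue  # do not descend into children of cutoff nodes
        for child in reversed(node.children):
            if child not in seen:
                seen.add(child)
                lifo.append(child)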

Comment on lines 86 to 88
    e.is_pointwise or isinstance(e, FusedExpr)
    for e in traversal(list(sub_expr.children))
), f"Invalid FusedExpr sub-expression: {sub_expr}"
Contributor


I would prefer to validate this once, before lowering to a graph, rather than on construction. This is where it's "annoying" that FusedExprs have two types of children: the actual sub-expressions and the FusedExpr "children".

)

pi: MutableMapping[IR, PartitionInfo] = {child: child_partition_info}
schema = {col.name: col.dtype for col in traversal([on]) if isinstance(col, Col)}
Contributor


This seems very wrong: it's extracting all columns that are leaves of on (an arbitrary expression that we are shuffling on) and using those as the schema of the Select call, which only selects a single column (shuffle_on).

What are you trying to do here?

I suppose what you're trying to do is determine which columns of the child are required to evaluate shuffle_on?

Member Author


I can avoid all this for now if you'd rather keep things simple (and there could definitely be a mistake).

The general idea is that we need to shuffle the input IR node to evaluate this expression, but the output is only a single column. So, in order to avoid shuffling a bunch of columns that we don't actually need to evaluate the expression, we (1) look for the underlying columns needed by the expression, (2) drop all other columns from the input IR node, and (3) shuffle only the columns we need.

Contributor


That makes sense, but those columns need to be in the list of expressions you're selecting

Member Author


but those columns need to be in the list of expressions you're selecting

I don't think I understand this comment very well. Are you saying that we cannot drop columns because we might need them in the overarching Select, even if we don't need them for the current FusedExpr we are processing in this function?

When we drop columns and shuffle IR, the modified/shuffled IR is only used for this specific FusedExpr. If other FusedExpr nodes depended on other columns, they would still be referencing the original input IR.

Contributor


As I understand it, you're shuffling on some expression A. This expression depends on some columns (B, C). The child input might have additional columns (D, E, ...). And you want to ensure that you only move around the data you need to evaluate A.

So I think this is:

necessary_columns = [e for e in traversal([on]) if isinstance(e, expr.Col)]

input = Select({c.name: c.dtype for c in necessary_columns},
               necessary_columns,
               False,
               child)

shuffled = Shuffle(..., shuffle_on, input)
No?

Comment on lines 334 to 339
child = Select(
    schema,
    shuffle_on,
    False,  # noqa: FBT003
    child,
)
Contributor


question: it seems like this Select might induce a shuffle itself (if, for example, shuffle_on contains an n_unique aggregation, or later a sort). Will that break anything? Or is it guaranteed not to be the case?

Member Author


Will that break anything? Or is it guaranteed not to be the case?

This is guaranteed not to be the case, because we have decomposed everything into a graph of FusedExpr nodes. A FusedExpr may only contain a single non-pointwise expression.

Member Author


Actually, it is possible that two distinct FusedExpr nodes will end up shuffling the same input IR multiple times (on the same or different columns). I haven't thought through ways to optimize this case, but I don't think anything should "break".

if set(schema) != set(child.schema):
    # Drop unnecessary columns before the shuffle
    child = Select(
        schema,
Contributor


The schema of this select is {named_expr.name: named_expr.value.dtype}

Member Author


Right, but this function is creating a new IR node that is shuffled (and that the original Select can be called on in a pointwise fashion). This temporary schema may have multiple columns (e.g. if we are selecting the uniques of some binary op between two columns).

Contributor


But that's not what this Select expression does: it produces a dataframe with a single output column that is the result of evaluating shuffle_on on the input.

@rjzamora
Member Author

rjzamora commented Apr 9, 2025

Update: I decided that I don't like how much work we are doing during graph-construction time. I prefer handling the decomposition entirely during the lowering/re-write process. I'm exploring a slight refactoring of this PR in rjzamora:complex-aggregations-refactor.

It may take me a few more hours tomorrow morning to clean that up and get shuffling implemented for n_unique (it also seems fine to leave n_unique for a follow-up).

@rjzamora rjzamora marked this pull request as draft April 14, 2025 16:08
@rjzamora
Member Author

Update: It is likely that #18492 will supersede this PR. That PR accomplishes what we need without introducing any new (or Expr-specific) task-graph logic.

@rjzamora
Member Author

Closing in favor of #18492

@rjzamora rjzamora closed this Apr 14, 2025
@github-project-automation github-project-automation bot moved this from In Progress to Done in cuDF Python Apr 14, 2025