Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
72dddb5
Rewrite groupbys during IR translation
wence- Mar 14, 2025
82cba40
Add more tests of groupby aggregations
wence- Mar 17, 2025
43763d3
Type alias for Schema
wence- Mar 25, 2025
3f2609d
Handle decomposition of empty agg
wence- Mar 25, 2025
8396bc0
Test coverage of rolling not implemented
wence- Mar 25, 2025
aee38fe
More test coverage
wence- Mar 25, 2025
130c670
Filter expression is not pointwise
wence- Mar 26, 2025
246a046
Vertical string concat is not pointwise
wence- Mar 26, 2025
ef24aae
Don't support ternary in grouped aggregations
wence- Mar 26, 2025
1315a7b
Fix docstring
wence- Mar 26, 2025
95ac50b
Take names to be unique against in unique_names
wence- Mar 26, 2025
d26008b
Better fill null
wence- Mar 26, 2025
4acd0e7
Drop/fill nulls only if necessary
wence- Mar 26, 2025
ae03abf
More groupby tests
wence- Mar 26, 2025
cde3cbd
These tests now pass
wence- Mar 26, 2025
73a249f
Merge branch 'branch-25.06' into wence/fea/polars-rewrite-groupby
wence- Mar 31, 2025
efb628b
Merge branch 'branch-25.06' into wence/fea/polars-rewrite-groupby
wence- Apr 2, 2025
d093128
Merge branch 'branch-25.06' into wence/fea/polars-rewrite-groupby
wence- Apr 9, 2025
c99d4ad
Merge remote-tracking branch 'upstream/branch-25.06' into wence/fea/p…
wence- Apr 25, 2025
d34d2d0
Merge remote-tracking branch 'upstream/branch-25.06' into wence/fea/p…
wence- Apr 28, 2025
60d2bed
Better docstrings
wence- Apr 29, 2025
5f55c71
Merge remote-tracking branch 'upstream/branch-25.06' into wence/fea/p…
wence- Apr 29, 2025
4363793
Use NamedExpr.reconstruct
wence- Apr 29, 2025
e184f42
Reinstate test (now passing)
wence- Apr 29, 2025
6203713
Expunge mapping argument from expression evaluation
wence- Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 10 additions & 40 deletions python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,10 @@
import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.base import (
AggInfo,
ExecutionContext,
Expr,
)
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
from cudf_polars.dsl.expressions.literal import Literal
from cudf_polars.dsl.expressions.unary import UnaryFunction

if TYPE_CHECKING:
from collections.abc import Mapping

from cudf_polars.containers import DataFrame

__all__ = ["Agg"]
Expand Down Expand Up @@ -122,38 +115,19 @@ def __init__(
"linear": plc.types.Interpolation.LINEAR,
}

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
if depth >= 1:
raise NotImplementedError(
"Nested aggregations in groupby"
) # pragma: no cover; check_agg trips first
if (isminmax := self.name in {"min", "max"}) and self.options:
raise NotImplementedError("Nan propagation in groupby for min/max")
(child,) = self.children
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
request = self.request
# These are handled specially here because we don't set up the
# request for the whole-frame agg because we can avoid a
# reduce for these.
@property
def agg_request(self) -> plc.aggregation.Aggregation: # noqa: D102
if self.name == "first":
request = plc.aggregation.nth_element(
return plc.aggregation.nth_element(
0, null_handling=plc.types.NullPolicy.INCLUDE
)
elif self.name == "last":
request = plc.aggregation.nth_element(
return plc.aggregation.nth_element(
-1, null_handling=plc.types.NullPolicy.INCLUDE
)
if request is None:
raise NotImplementedError(
f"Aggregation {self.name} in groupby"
) # pragma: no cover; __init__ trips first
if isminmax and plc.traits.is_floating_point(self.dtype):
assert expr is not None
# Ignore nans in these groupby aggs, do this by masking
# nans in the input
expr = UnaryFunction(self.dtype, "mask_nans", (), expr)
return AggInfo([(expr, request, self)])
else:
assert self.request is not None, "Init should have raised"
return self.request

def _reduce(
self, column: Column, *, request: plc.aggregation.Aggregation
Expand Down Expand Up @@ -215,11 +189,7 @@ def _last(self, column: Column) -> Column:
return Column(plc.copying.slice(column.obj, [n - 1, n])[0])

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if context is not ExecutionContext.FRAME:
Expand All @@ -230,4 +200,4 @@ def do_evaluate(
# Aggregations like quantiles may have additional children that were
# preprocessed into pylibcudf requests.
child = self.children[0]
return self.op(child.evaluate(df, context=context, mapping=mapping))
return self.op(child.evaluate(df, context=context))
94 changes: 21 additions & 73 deletions python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from cudf_polars.dsl.nodebase import Node

if TYPE_CHECKING:
from collections.abc import Mapping

from typing_extensions import Self

from cudf_polars.containers import Column, DataFrame
Expand Down Expand Up @@ -48,11 +46,7 @@ class Expr(Node["Expr"]):
"""Names of non-child data (not Exprs) for reconstruction."""

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""
Evaluate this expression given a dataframe for context.
Expand All @@ -63,15 +57,10 @@ def do_evaluate(
DataFrame that will provide columns.
context
What context are we performing this evaluation in?
mapping
Substitution mapping from expressions to Columns, used to
override the evaluation of a given expression if we're
performing a simple rewritten evaluation.

Notes
-----
Do not call this function directly, but rather
:meth:`evaluate` which handles the mapping lookups.
Do not call this function directly, but rather :meth:`evaluate`.

Returns
-------
Expand All @@ -89,11 +78,7 @@ def do_evaluate(
) # pragma: no cover; translation of unimplemented nodes trips first

def evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""
Evaluate this expression given a dataframe for context.
Expand All @@ -104,10 +89,6 @@ def evaluate(
DataFrame that will provide columns.
context
What context are we performing this evaluation in?
mapping
Substitution mapping from expressions to Columns, used to
override the evaluation of a given expression if we're
performing a simple rewritten evaluation.

Notes
-----
Expand All @@ -126,37 +107,28 @@ def evaluate(
are returned during translation to the IR, but for now we
are not perfect.
"""
if mapping is None:
return self.do_evaluate(df, context=context, mapping=mapping)
try:
return mapping[self]
except KeyError:
return self.do_evaluate(df, context=context, mapping=mapping)

def collect_agg(self, *, depth: int) -> AggInfo:
"""
Collect information about aggregations in groupbys.
return self.do_evaluate(df, context=context)

Parameters
----------
depth
The depth of aggregating (reduction or sampling)
expressions we are currently at.
@property
def agg_request(self) -> plc.aggregation.Aggregation:
"""
The aggregation for this expression in a grouped aggregation.

Returns
-------
Aggregation info describing the expression to aggregate in the
groupby.
Aggregation request. Default is to collect the expression.

Notes
-----
This presumes that the IR translation has decomposed groupby
reductions only into cases we can handle.

Raises
------
NotImplementedError
If we can't currently perform the aggregation request, for
example nested aggregations like ``a.max().min()``.
If requesting an aggregation from an unexpected expression.
"""
raise NotImplementedError(
f"Collecting aggregation info for {type(self).__name__}"
) # pragma: no cover; check_agg trips first
return plc.aggregation.collect_list()


class ErrorExpr(Expr):
Expand All @@ -168,7 +140,7 @@ def __init__(self, dtype: plc.DataType, error: str) -> None:
self.dtype = dtype
self.error = error
self.children = ()
self.is_pointwise = True
self.is_pointwise = False


class NamedExpr:
Expand Down Expand Up @@ -204,11 +176,7 @@ def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""
Evaluate this expression given a dataframe for context.
Expand All @@ -219,8 +187,6 @@ def evaluate(
DataFrame providing context
context
Execution context
mapping
Substitution mapping

Returns
-------
Expand All @@ -231,13 +197,7 @@ def evaluate(
:meth:`Expr.evaluate` for details, this function just adds the
name to a column produced from an expression.
"""
return self.value.evaluate(df, context=context, mapping=mapping).rename(
self.name
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
return self.value.collect_agg(depth=depth)
return self.value.evaluate(df, context=context).rename(self.name)

def reconstruct(self, expr: Expr) -> Self:
"""
Expand Down Expand Up @@ -270,21 +230,13 @@ def __init__(self, dtype: plc.DataType, name: str) -> None:
self.children = ()

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# Deliberately remove the name here so that we guarantee
# evaluation of the IR produces names.
return df.column_map[self.name].rename(None)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
return AggInfo([(self, plc.aggregation.collect_list(), self)])


class ColRef(Expr):
__slots__ = ("index", "table_ref")
Expand All @@ -308,11 +260,7 @@ def __init__(
self.children = (column,)

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""Evaluate this expression given a dataframe for context."""
raise NotImplementedError(
Expand Down
42 changes: 3 additions & 39 deletions python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr

if TYPE_CHECKING:
from collections.abc import Mapping

from cudf_polars.containers import DataFrame

__all__ = ["BinOp"]
Expand Down Expand Up @@ -85,17 +83,10 @@ def __init__(
}

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""Evaluate this expression given a dataframe for context."""
left, right = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
left, right = (child.evaluate(df, context=context) for child in self.children)
lop = left.obj
rop = right.obj
if left.size != right.size:
Expand All @@ -106,30 +97,3 @@ def do_evaluate(
return Column(
plc.binaryop.binary_operation(lop, rop, self.op, self.dtype),
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
if depth == 1:
# inside aggregation, need to pre-evaluate,
# groupby construction has checked that we don't have
# nested aggs, so stop the recursion and return ourselves
# for pre-eval
return AggInfo([(self, plc.aggregation.collect_list(), self)])
else:
left_info, right_info = (
child.collect_agg(depth=depth) for child in self.children
)
requests = [*left_info.requests, *right_info.requests]
# TODO: Hack, if there were no reductions inside this
# binary expression then we want to pre-evaluate and
# collect ourselves. Otherwise we want to collect the
# aggregations inside and post-evaluate. This is a bad way
# of checking that we are in case 1.
if all(
agg.kind() == plc.aggregation.Kind.COLLECT_LIST
for _, agg, _ in requests
):
return AggInfo([(self, plc.aggregation.collect_list(), self)])
return AggInfo(
[*left_info.requests, *right_info.requests],
)
15 changes: 3 additions & 12 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
)

if TYPE_CHECKING:
from collections.abc import Mapping

from typing_extensions import Self

import polars.type_aliases as pl_types
Expand Down Expand Up @@ -145,11 +143,7 @@ def _distinct(
}

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name in (
Expand All @@ -162,7 +156,7 @@ def do_evaluate(
if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
value = plc.Scalar.from_py(is_finite)
return Column(plc.Column.from_scalar(value, df.num_rows))
needles = child.evaluate(df, context=context, mapping=mapping)
needles = child.evaluate(df, context=context)
to_search = [-float("inf"), float("inf")]
if is_finite:
# NaN is neither finite not infinite
Expand All @@ -177,10 +171,7 @@ def do_evaluate(
if is_finite:
result = plc.unary.unary_operation(result, plc.unary.UnaryOperator.NOT)
return Column(result)
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
]
columns = [child.evaluate(df, context=context) for child in self.children]
# Kleene logic for Any (OR) and All (AND) if ignore_nulls is
# False
if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All):
Expand Down
Loading
Loading