diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py index 29562aee4c7..4763b3cd823 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py @@ -12,17 +12,10 @@ import pylibcudf as plc from cudf_polars.containers import Column -from cudf_polars.dsl.expressions.base import ( - AggInfo, - ExecutionContext, - Expr, -) +from cudf_polars.dsl.expressions.base import ExecutionContext, Expr from cudf_polars.dsl.expressions.literal import Literal -from cudf_polars.dsl.expressions.unary import UnaryFunction if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame __all__ = ["Agg"] @@ -122,38 +115,19 @@ def __init__( "linear": plc.types.Interpolation.LINEAR, } - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - if depth >= 1: - raise NotImplementedError( - "Nested aggregations in groupby" - ) # pragma: no cover; check_agg trips first - if (isminmax := self.name in {"min", "max"}) and self.options: - raise NotImplementedError("Nan propagation in groupby for min/max") - (child,) = self.children - ((expr, _, _),) = child.collect_agg(depth=depth + 1).requests - request = self.request - # These are handled specially here because we don't set up the - # request for the whole-frame agg because we can avoid a - # reduce for these. + @property + def agg_request(self) -> plc.aggregation.Aggregation: # noqa: D102 if self.name == "first": - request = plc.aggregation.nth_element( + return plc.aggregation.nth_element( 0, null_handling=plc.types.NullPolicy.INCLUDE ) elif self.name == "last": - request = plc.aggregation.nth_element( + return plc.aggregation.nth_element( -1, null_handling=plc.types.NullPolicy.INCLUDE ) - if request is None: - raise NotImplementedError( - f"Aggregation {self.name} in groupby" - ) # pragma: no cover; __init__ trips first - if isminmax and plc.traits.is_floating_point(self.dtype): - assert expr is not None - # Ignore nans in these groupby aggs, do this by masking - # nans in the input - expr = UnaryFunction(self.dtype, "mask_nans", (), expr) - return AggInfo([(expr, request, self)]) + else: + assert self.request is not None, "Init should have raised" + return self.request def _reduce( self, column: Column, *, request: plc.aggregation.Aggregation @@ -215,11 +189,7 @@ def _last(self, column: Column) -> Column: return Column(plc.copying.slice(column.obj, [n - 1, n])[0]) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" if context is not ExecutionContext.FRAME: @@ -230,4 +200,4 @@ def do_evaluate( # Aggregations like quantiles may have additional children that were # preprocessed into pylibcudf requests. 
child = self.children[0] - return self.op(child.evaluate(df, context=context, mapping=mapping)) + return self.op(child.evaluate(df, context=context)) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/base.py b/python/cudf_polars/cudf_polars/dsl/expressions/base.py index 680d176f83f..7ed7f782f28 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/base.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/base.py @@ -16,8 +16,6 @@ from cudf_polars.dsl.nodebase import Node if TYPE_CHECKING: - from collections.abc import Mapping - from typing_extensions import Self from cudf_polars.containers import Column, DataFrame @@ -48,11 +46,7 @@ class Expr(Node["Expr"]): """Names of non-child data (not Exprs) for reconstruction.""" def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """ Evaluate this expression given a dataframe for context. @@ -63,15 +57,10 @@ def do_evaluate( DataFrame that will provide columns. context What context are we performing this evaluation in? - mapping - Substitution mapping from expressions to Columns, used to - override the evaluation of a given expression if we're - performing a simple rewritten evaluation. Notes ----- - Do not call this function directly, but rather - :meth:`evaluate` which handles the mapping lookups. + Do not call this function directly, but rather :meth:`evaluate`. Returns ------- @@ -89,11 +78,7 @@ def do_evaluate( ) # pragma: no cover; translation of unimplemented nodes trips first def evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """ Evaluate this expression given a dataframe for context. @@ -104,10 +89,6 @@ def evaluate( DataFrame that will provide columns. context What context are we performing this evaluation in? - mapping - Substitution mapping from expressions to Columns, used to - override the evaluation of a given expression if we're - performing a simple rewritten evaluation. Notes ----- @@ -126,37 +107,28 @@ def evaluate( are returned during translation to the IR, but for now we are not perfect. """ - if mapping is None: - return self.do_evaluate(df, context=context, mapping=mapping) - try: - return mapping[self] - except KeyError: - return self.do_evaluate(df, context=context, mapping=mapping) - - def collect_agg(self, *, depth: int) -> AggInfo: - """ - Collect information about aggregations in groupbys. + return self.do_evaluate(df, context=context) - Parameters - ---------- - depth - The depth of aggregating (reduction or sampling) - expressions we are currently at. + @property + def agg_request(self) -> plc.aggregation.Aggregation: + """ + The aggregation for this expression in a grouped aggregation. Returns ------- - Aggregation info describing the expression to aggregate in the - groupby. + Aggregation request. Default is to collect the expression. + + Notes + ----- + This presumes that the IR translation has decomposed groupby + reductions only into cases we can handle. Raises ------ NotImplementedError - If we can't currently perform the aggregation request, for - example nested aggregations like ``a.max().min()``. + If requesting an aggregation from an unexpected expression. 
""" - raise NotImplementedError( - f"Collecting aggregation info for {type(self).__name__}" - ) # pragma: no cover; check_agg trips first + return plc.aggregation.collect_list() class ErrorExpr(Expr): @@ -168,7 +140,7 @@ def __init__(self, dtype: plc.DataType, error: str) -> None: self.dtype = dtype self.error = error self.children = () - self.is_pointwise = True + self.is_pointwise = False class NamedExpr: @@ -204,11 +176,7 @@ def __ne__(self, other: Any) -> bool: return not self.__eq__(other) def evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """ Evaluate this expression given a dataframe for context. @@ -219,8 +187,6 @@ def evaluate( DataFrame providing context context Execution context - mapping - Substitution mapping Returns ------- @@ -231,13 +197,7 @@ def evaluate( :meth:`Expr.evaluate` for details, this function just adds the name to a column produced from an expression. """ - return self.value.evaluate(df, context=context, mapping=mapping).rename( - self.name - ) - - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - return self.value.collect_agg(depth=depth) + return self.value.evaluate(df, context=context).rename(self.name) def reconstruct(self, expr: Expr) -> Self: """ @@ -270,21 +230,13 @@ def __init__(self, dtype: plc.DataType, name: str) -> None: self.children = () def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" # Deliberately remove the name here so that we guarantee # evaluation of the IR produces names. 
return df.column_map[self.name].rename(None) - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - return AggInfo([(self, plc.aggregation.collect_list(), self)]) - class ColRef(Expr): __slots__ = ("index", "table_ref") @@ -308,11 +260,7 @@ def __init__( self.children = (column,) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" raise NotImplementedError( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py index 84fd179aedd..477307b82ea 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py @@ -13,11 +13,9 @@ import pylibcudf as plc from cudf_polars.containers import Column -from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr +from cudf_polars.dsl.expressions.base import ExecutionContext, Expr if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame __all__ = ["BinOp"] @@ -85,17 +83,10 @@ def __init__( } def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" - left, right = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + left, right = (child.evaluate(df, context=context) for child in self.children) lop = left.obj rop = right.obj if left.size != right.size: @@ -106,30 +97,3 @@ def do_evaluate( return Column( plc.binaryop.binary_operation(lop, rop, self.op, self.dtype), ) - - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - if depth == 1: - # inside aggregation, need to pre-evaluate, - # groupby construction has checked that we don't have - # nested aggs, so stop the recursion and return ourselves - # for pre-eval - return AggInfo([(self, plc.aggregation.collect_list(), self)]) - else: - left_info, right_info = ( - child.collect_agg(depth=depth) for child in self.children - ) - requests = [*left_info.requests, *right_info.requests] - # TODO: Hack, if there were no reductions inside this - # binary expression then we want to pre-evaluate and - # collect ourselves. Otherwise we want to collect the - # aggregations inside and post-evaluate. This is a bad way - # of checking that we are in case 1. 
- if all( - agg.kind() == plc.aggregation.Kind.COLLECT_LIST - for _, agg, _ in requests - ): - return AggInfo([(self, plc.aggregation.collect_list(), self)]) - return AggInfo( - [*left_info.requests, *right_info.requests], - ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index e5696865439..c30ba511bc0 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -21,8 +21,6 @@ ) if TYPE_CHECKING: - from collections.abc import Mapping - from typing_extensions import Self import polars.type_aliases as pl_types @@ -145,11 +143,7 @@ def _distinct( } def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name in ( @@ -162,7 +156,7 @@ def do_evaluate( if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): value = plc.Scalar.from_py(is_finite) return Column(plc.Column.from_scalar(value, df.num_rows)) - needles = child.evaluate(df, context=context, mapping=mapping) + needles = child.evaluate(df, context=context) to_search = [-float("inf"), float("inf")] if is_finite: # NaN is neither finite not infinite @@ -177,10 +171,7 @@ def do_evaluate( if is_finite: result = plc.unary.unary_operation(result, plc.unary.UnaryOperator.NOT) return Column(result) - columns = [ - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ] + columns = [child.evaluate(df, context=context) for child in self.children] # Kleene logic for Any (OR) and All (AND) if ignore_nulls is # False if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index 2e12291acb4..2b1816f36d4 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -17,8 +17,6 @@ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr if TYPE_CHECKING: - from collections.abc import Mapping - from typing_extensions import Self from polars.polars import _expr_nodes as pl_expr @@ -137,17 +135,10 @@ def __init__( raise NotImplementedError("ToString is not supported on duration types") def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" - columns = [ - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ] + columns = [child.evaluate(df, context=context) for child in self.children] (column,) = columns if self.name is TemporalFunction.Name.CastTimeUnit: (unit,) = self.options diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py index b2007bcc6f0..4cb53758133 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py @@ -6,15 +6,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, 
NoReturn import pylibcudf as plc from cudf_polars.containers import Column -from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr +from cudf_polars.dsl.expressions.base import ExecutionContext, Expr if TYPE_CHECKING: - from collections.abc import Hashable, Mapping + from collections.abc import Hashable import pyarrow as pa @@ -36,19 +36,17 @@ def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None: self.is_pointwise = True def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" # datatype of pyarrow scalar is correct by construction. return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1)) - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - return AggInfo([]) + @property + def agg_request(self) -> NoReturn: # noqa: D102 + raise NotImplementedError( + "Not expecting to require agg request of literal" + ) # pragma: no cover class LiteralColumn(Expr): @@ -70,16 +68,14 @@ def get_hashable(self) -> Hashable: return (type(self), self.dtype, id(self.value)) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" # datatype of pyarrow array is correct by construction. return Column(plc.interop.from_arrow(self.value)) - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - return AggInfo([]) + @property + def agg_request(self) -> NoReturn: # noqa: D102 + raise NotImplementedError( + "Not expecting to require agg request of literal" + ) # pragma: no cover diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py index d87b6585df5..36e9e1c5ba8 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py @@ -14,8 +14,6 @@ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame __all__ = ["Filter", "Gather"] @@ -31,16 +29,11 @@ def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None: self.is_pointwise = False def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" values, indices = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children + child.evaluate(df, context=context) for child in self.children ) lo, hi = plc.reduce.minmax(indices.obj) lo = plc.interop.to_arrow(lo).as_py() @@ -68,20 +61,13 @@ class Filter(Expr): def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr): self.dtype = dtype self.children = (values, indices) - self.is_pointwise = True + self.is_pointwise = False def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - 
mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" - values, mask = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + values, mask = (child.evaluate(df, context=context) for child in self.children) table = plc.stream_compaction.apply_boolean_mask( plc.Table([values.obj]), mask.obj ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py b/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py index 72ce90966db..ea855deabf1 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py @@ -14,8 +14,6 @@ ) if TYPE_CHECKING: - from collections.abc import Mapping - import pylibcudf as plc from cudf_polars.containers import Column, DataFrame @@ -41,13 +39,9 @@ def __init__( self.children = (column,) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" (child,) = self.children - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) return column.slice((self.offset, self.length)) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py index acc81e26a8e..d12b2f5bce2 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py @@ -15,8 +15,6 @@ from cudf_polars.utils import sorting if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame __all__ = ["Sort", "SortBy"] @@ -35,15 +33,11 @@ def __init__( self.is_pointwise = False def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" (child,) = self.children - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) (stable, nulls_last, descending) = self.options order, null_order = sorting.sort_order( [descending], nulls_last=[nulls_last], num_keys=1 @@ -75,17 +69,10 @@ def __init__( self.is_pointwise = False def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" - column, *by = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + column, *by = (child.evaluate(df, context=context) for child in self.children) (stable, nulls_last, descending) = self.options order, null_order = sorting.sort_order( descending, nulls_last=nulls_last, num_keys=len(by) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index d7d1fbce7b8..fd7d76313d4 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ 
b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -21,8 +21,6 @@ from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn if TYPE_CHECKING: - from collections.abc import Mapping - from typing_extensions import Self from polars.polars import _expr_nodes as pl_expr @@ -107,7 +105,7 @@ def __init__( self.options = options self.name = name self.children = children - self.is_pointwise = True + self.is_pointwise = self.name != StringFunction.Name.ConcatVertical self._validate_input() def _validate_input(self) -> None: @@ -203,16 +201,12 @@ def _validate_input(self) -> None: ) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name is StringFunction.Name.ConcatVertical: (child,) = self.children - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) delimiter, ignore_nulls = self.options if column.null_count > 0 and not ignore_nulls: return Column(plc.Column.all_null_like(column.obj, 1)) @@ -225,11 +219,11 @@ def do_evaluate( ) elif self.name is StringFunction.Name.Contains: child, arg = self.children - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) literal, _ = self.options if literal: - pat = arg.evaluate(df, context=context, mapping=mapping) + pat = arg.evaluate(df, context=context) pattern = ( pat.obj_scalar if pat.is_scalar and pat.size != column.size @@ -245,7 +239,7 @@ def do_evaluate( assert isinstance(expr_offset, Literal) assert isinstance(expr_length, Literal) - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) # libcudf slices via [start,stop). # polars slices with offset + length where start == offset # stop = start + length. 
Negative values for start look backward @@ -275,9 +269,7 @@ def do_evaluate( StringFunction.Name.StripCharsStart, StringFunction.Name.StripCharsEnd, }: - column, chars = ( - c.evaluate(df, context=context, mapping=mapping) for c in self.children - ) + column, chars = (c.evaluate(df, context=context) for c in self.children) if self.name is StringFunction.Name.StripCharsStart: side = plc.strings.SideType.LEFT elif self.name is StringFunction.Name.StripCharsEnd: @@ -286,10 +278,7 @@ def do_evaluate( side = plc.strings.SideType.BOTH return Column(plc.strings.strip.strip(column.obj, side, chars.obj_scalar)) - columns = [ - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ] + columns = [child.evaluate(df, context=context) for child in self.children] if self.name is StringFunction.Name.Lowercase: (column,) = columns return Column(plc.strings.case.to_lower(column.obj)) @@ -319,7 +308,7 @@ def do_evaluate( elif self.name is StringFunction.Name.Strptime: # TODO: ignores ambiguous format, strict, exact, cache = self.options - col = self.children[0].evaluate(df, context=context, mapping=mapping) + col = self.children[0].evaluate(df, context=context) is_timestamps = plc.strings.convert.convert_datetime.is_timestamp( col.obj, format diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py index 120ca8edce0..f59b818d013 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 # TODO: remove need for this # ruff: noqa: D101 @@ -17,8 +17,6 @@ ) if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame @@ -37,16 +35,11 @@ def __init__( self.is_pointwise = True def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" when, then, otherwise = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children + child.evaluate(df, context=context) for child in self.children ) then_obj = then.obj_scalar if then.is_scalar else then.obj otherwise_obj = otherwise.obj_scalar if otherwise.is_scalar else otherwise.obj diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py index 9194cd0421e..4b61901d963 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py @@ -10,13 +10,11 @@ import pylibcudf as plc from cudf_polars.containers import Column -from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr +from cudf_polars.dsl.expressions.base import ExecutionContext, Expr from cudf_polars.dsl.expressions.literal import Literal from cudf_polars.utils import dtypes if TYPE_CHECKING: - from collections.abc import Mapping - from cudf_polars.containers import DataFrame __all__ = ["Cast", "Len", "UnaryFunction"] @@ -38,23 +36,13 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: ) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext 
= ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" (child,) = self.children - column = child.evaluate(df, context=context, mapping=mapping) + column = child.evaluate(df, context=context) return column.astype(self.dtype) - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - # TODO: Could do with sort-based groupby and segmented filter - (child,) = self.children - return child.collect_agg(depth=depth) - class Len(Expr): """Class representing the length of an expression.""" @@ -65,11 +53,7 @@ def __init__(self, dtype: plc.DataType) -> None: self.is_pointwise = False def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" return Column( @@ -79,12 +63,9 @@ def do_evaluate( ) ) - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - # TODO: polars returns a uint, not an int for count - return AggInfo( - [(None, plc.aggregation.count(plc.types.NullPolicy.INCLUDE), self)] - ) + @property + def agg_request(self) -> plc.aggregation.Aggregation: # noqa: D102 + return plc.aggregation.count(plc.types.NullPolicy.INCLUDE) class UnaryFunction(Expr): @@ -165,22 +146,15 @@ def __init__( ) def do_evaluate( - self, - df: DataFrame, - *, - context: ExecutionContext = ExecutionContext.FRAME, - mapping: Mapping[Expr, Column] | None = None, + self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name == "mask_nans": (child,) = self.children - return child.evaluate(df, context=context, mapping=mapping).mask_nans() + return child.evaluate(df, context=context).mask_nans() if self.name == "round": (decimal_places,) = self.options - (values,) = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + (values,) = (child.evaluate(df, context=context) for child in self.children) return Column( plc.round.round( values.obj, decimal_places, plc.round.RoundingMethod.HALF_UP @@ -188,10 +162,7 @@ def do_evaluate( ).sorted_like(values) elif self.name == "unique": (maintain_order,) = self.options - (values,) = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + (values,) = (child.evaluate(df, context=context) for child in self.children) # Only one column, so keep_any is the same as keep_first # for stable distinct keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY @@ -221,10 +192,7 @@ def do_evaluate( return Column(column).sorted_like(values) return Column(column) elif self.name == "set_sorted": - (column,) = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + (column,) = (child.evaluate(df, context=context) for child in self.children) (asc,) = self.options order = ( plc.types.Order.ASCENDING @@ -248,34 +216,33 @@ def do_evaluate( null_order=null_order, ) elif self.name == "drop_nulls": - (column,) = ( - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ) + (column,) = (child.evaluate(df, context=context) for child in self.children) + if 
column.null_count == 0: + return column return Column( plc.stream_compaction.drop_nulls( plc.Table([column.obj]), [0], 1 ).columns()[0] ) elif self.name == "fill_null": - column = self.children[0].evaluate(df, context=context, mapping=mapping) + column = self.children[0].evaluate(df, context=context) + if column.null_count == 0: + return column if isinstance(self.children[1], Literal): arg = plc.interop.from_arrow(self.children[1].value) else: - evaluated = self.children[1].evaluate( - df, context=context, mapping=mapping - ) + evaluated = self.children[1].evaluate(df, context=context) arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj return Column(plc.replace.replace_nulls(column.obj, arg)) elif self.name in self._OP_MAPPING: - column = self.children[0].evaluate(df, context=context, mapping=mapping) + column = self.children[0].evaluate(df, context=context) if column.obj.type().id() != self.dtype.id(): arg = plc.unary.cast(column.obj, self.dtype) else: arg = column.obj return Column(plc.unary.unary_operation(arg, self._OP_MAPPING[self.name])) elif self.name in UnaryFunction._supported_cum_aggs: - column = self.children[0].evaluate(df, context=context, mapping=mapping) + column = self.children[0].evaluate(df, context=context) plc_col = column.obj col_type = column.obj.type() # cum_sum casts @@ -321,16 +288,3 @@ def do_evaluate( raise NotImplementedError( f"Unimplemented unary function {self.name=}" ) # pragma: no cover; init trips first - - def collect_agg(self, *, depth: int) -> AggInfo: - """Collect information about aggregations in groupbys.""" - if self.name in {"unique", "drop_nulls"} | self._supported_cum_aggs: - raise NotImplementedError(f"{self.name} in groupby") - if depth == 1: - # inside aggregation, need to pre-evaluate, groupby - # construction has checked that we don't have nested aggs, - # so stop the recursion and return ourselves for pre-eval - return AggInfo([(self, plc.aggregation.collect_list(), self)]) - else: - (child,) = self.children - return child.collect_agg(depth=depth) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 04daf7c6028..dc04904f841 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -904,42 +904,19 @@ def do_evaluate( class GroupBy(IR): """Perform a groupby.""" - class AggInfos: - """Serializable wrapper for GroupBy aggregation info.""" - - agg_requests: Sequence[expr.NamedExpr] - agg_infos: Sequence[expr.AggInfo] - - def __init__(self, agg_requests: Sequence[expr.NamedExpr]): - self.agg_requests = tuple(agg_requests) - self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] - - def __reduce__(self) -> tuple[Any, ...]: - """Pickle an AggInfos object.""" - return (type(self), (self.agg_requests,)) - - class GroupbyOptions: - """Serializable wrapper for polars GroupbyOptions.""" - - def __init__(self, polars_groupby_options: Any): - self.dynamic = polars_groupby_options.dynamic - self.rolling = polars_groupby_options.rolling - self.slice = polars_groupby_options.slice - __slots__ = ( - "agg_infos", "agg_requests", "config_options", "keys", "maintain_order", - "options", + "zlice", ) _non_child = ( "schema", "keys", "agg_requests", "maintain_order", - "options", + "zlice", "config_options", ) keys: tuple[expr.NamedExpr, ...] 
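# A minimal sketch (not part of the patch) of the Zlice semantics assumed by
# the new ``zlice`` slot above: an (offset, length) pair applied to the grouped
# result, with negative offsets counting from the end as in polars' slice
# pushdown. The helper name is hypothetical.
def apply_zlice(nrows: int, zlice: tuple[int, int] | None) -> range:
    """Return the row range that slicing a frame of ``nrows`` rows keeps."""
    if zlice is None:
        return range(nrows)
    offset, length = zlice
    start = nrows + offset if offset < 0 else offset
    start = max(0, min(start, nrows))
    return range(start, min(start + length, nrows))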
@@ -948,8 +925,8 @@ def __init__(self, polars_groupby_options: Any): """Aggregation expressions.""" maintain_order: bool """Preserve order in groupby.""" - options: GroupbyOptions - """Arbitrary options.""" + zlice: Zlice | None + """Optional slice to apply after grouping.""" config_options: ConfigOptions """GPU-specific configuration options""" @@ -959,7 +936,7 @@ def __init__( keys: Sequence[expr.NamedExpr], agg_requests: Sequence[expr.NamedExpr], maintain_order: bool, # noqa: FBT001 - options: Any, + zlice: Zlice | None, config_options: ConfigOptions, df: IR, ): @@ -967,61 +944,23 @@ def __init__( self.keys = tuple(keys) self.agg_requests = tuple(agg_requests) self.maintain_order = maintain_order - self.options = self.GroupbyOptions(options) + self.zlice = zlice self.config_options = config_options self.children = (df,) - if self.options.rolling: - raise NotImplementedError( - "rolling window/groupby" - ) # pragma: no cover; rollingwindow constructor has already raised - if self.options.dynamic: - raise NotImplementedError("dynamic group by") - if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): - raise NotImplementedError("Nested aggregations in groupby") self._non_child_args = ( self.keys, self.agg_requests, maintain_order, - self.options, - self.AggInfos(self.agg_requests), + self.zlice, ) - @staticmethod - def check_agg(agg: expr.Expr) -> int: - """ - Determine if we can handle an aggregation expression. - - Parameters - ---------- - agg - Expression to check - - Returns - ------- - depth of nesting - - Raises - ------ - NotImplementedError - For unsupported expression nodes. - """ - if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)): - return max(GroupBy.check_agg(child) for child in agg.children) - elif isinstance(agg, expr.Agg): - return 1 + max(GroupBy.check_agg(child) for child in agg.children) - elif isinstance(agg, (expr.Len, expr.Col, expr.Literal, expr.LiteralColumn)): - return 0 - else: - raise NotImplementedError(f"No handler for {agg=}") - @classmethod def do_evaluate( cls, keys_in: Sequence[expr.NamedExpr], agg_requests: Sequence[expr.NamedExpr], maintain_order: bool, # noqa: FBT001 - options: GroupbyOptions, - agg_info_wrapper: AggInfos, + zlice: Zlice | None, df: DataFrame, ) -> DataFrame: """Evaluate and return a dataframe.""" @@ -1038,32 +977,35 @@ def do_evaluate( column_order=[k.order for k in keys], null_precedence=[k.null_order for k in keys], ) - # TODO: uniquify requests = [] - replacements: list[expr.Expr] = [] - for info in agg_info_wrapper.agg_infos: - for pre_eval, req, rep in info.requests: - if pre_eval is None: - # A count aggregation, doesn't touch the column, - # but we need to have one. Rather than evaluating - # one, just use one of the key columns. 
- col = keys[0].obj - else: - col = pre_eval.evaluate(df).obj - requests.append(plc.groupby.GroupByRequest(col, [req])) - replacements.append(rep) + names = [] + for request in agg_requests: + name = request.name + value = request.value + if isinstance(value, expr.Len): + # A count aggregation, we need a column so use a key column + col = keys[0].obj + elif isinstance(value, expr.Agg): + (child,) = value.children + col = child.evaluate(df).obj + else: + # Anything else, we pre-evaluate + col = value.evaluate(df).obj + requests.append(plc.groupby.GroupByRequest(col, [value.agg_request])) + names.append(name) group_keys, raw_tables = grouper.aggregate(requests) - raw_columns: list[Column] = [] - for i, table in enumerate(raw_tables): - (column,) = table.columns() - raw_columns.append(Column(column, name=f"tmp{i}")) - mapping = dict(zip(replacements, raw_columns, strict=True)) + results = [ + Column(column, name=name) + for name, column in zip( + names, + itertools.chain.from_iterable(t.columns() for t in raw_tables), + strict=True, + ) + ] result_keys = [ Column(grouped_key, name=key.name) for key, grouped_key in zip(keys, group_keys.columns(), strict=True) ] - result_subs = DataFrame(raw_columns) - results = [req.evaluate(result_subs, mapping=mapping) for req in agg_requests] broadcasted = broadcast(*result_keys, *results) # Handle order preservation of groups if maintain_order and not sorted: @@ -1106,7 +1048,7 @@ def do_evaluate( ordered_table.columns(), broadcasted, strict=True ) ] - return DataFrame(broadcasted).slice(options.slice) + return DataFrame(broadcasted).slice(zlice) class ConditionalJoin(IR): diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 6350853e3c4..66227eea4aa 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -22,7 +22,8 @@ from cudf_polars.dsl import expr, ir from cudf_polars.dsl.to_ast import insert_colrefs -from cudf_polars.typing import NodeTraverser +from cudf_polars.dsl.utils.groupby import rewrite_groupby +from cudf_polars.typing import Schema from cudf_polars.utils import config, dtypes, sorting if TYPE_CHECKING: @@ -168,6 +169,7 @@ class set_node(AbstractContextManager[None]): __slots__ = ("n", "visitor") visitor: NodeTraverser + n: int def __init__(self, visitor: NodeTraverser, n: int) -> None: @@ -187,18 +189,14 @@ def __exit__(self, *args: Any) -> None: @singledispatch -def _translate_ir( - node: Any, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR: raise NotImplementedError( f"Translation for {type(node).__name__}" ) # pragma: no cover @_translate_ir.register -def _( - node: pl_ir.PythonScan, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.PythonScan, translator: Translator, schema: Schema) -> ir.IR: scan_fn, with_columns, source_type, predicate, nrows = node.options options = (scan_fn, with_columns, source_type, nrows) predicate = ( @@ -208,9 +206,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Scan, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Scan, translator: Translator, schema: Schema) -> ir.IR: typ, *options = node.scan_type if typ == "ndjson": (reader_options,) = map(json.loads, options) @@ -248,18 +244,14 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Cache, translator: Translator, schema: dict[str, plc.DataType] -) -> 
ir.IR: +def _(node: pl_ir.Cache, translator: Translator, schema: Schema) -> ir.IR: return ir.Cache( schema, node.id_, node.cache_hits, translator.translate_ir(n=node.input) ) @_translate_ir.register -def _( - node: pl_ir.DataFrameScan, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.DataFrameScan, translator: Translator, schema: Schema) -> ir.IR: return ir.DataFrameScan( schema, node.df, @@ -269,9 +261,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Select, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Select, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) exprs = [translate_named_expr(translator, n=e) for e in node.expr] @@ -279,28 +269,25 @@ def _( @_translate_ir.register -def _( - node: pl_ir.GroupBy, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) - aggs = [translate_named_expr(translator, n=e) for e in node.aggs] keys = [translate_named_expr(translator, n=e) for e in node.keys] - return ir.GroupBy( - schema, - keys, - aggs, - node.maintain_order, - node.options, - translator.config_options, - inp, - ) + original_aggs = [translate_named_expr(translator, n=e) for e in node.aggs] + is_rolling = node.options.rolling is not None + is_dynamic = node.options.dynamic is not None + if is_dynamic: + raise NotImplementedError("group_by_dynamic") + elif is_rolling: + raise NotImplementedError("group_by_rolling") + else: + return rewrite_groupby( + node, schema, keys, original_aggs, translator.config_options, inp + ) @_translate_ir.register -def _( - node: pl_ir.Join, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR: # Join key dtypes are dependent on the schema of the left and # right inputs, so these must be translated with the relevant # input active. 
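# To illustrate the groupby translation above: rewrite_groupby decomposes each
# requested aggregation so that the GroupBy node only ever sees plain column
# aggregations, with arithmetic on aggregated results moved into a
# post-processing Select. A hedged, runnable polars-level sketch of that
# equivalence (data and names are illustrative, not part of the patch):
import polars as pl

df = pl.LazyFrame({"k": [1, 1, 2], "v": [1.0, 2.0, 3.0]})
original = df.group_by("k").agg((pl.col("v").sum() * 2).alias("x"))
rewritten = (
    df.group_by("k")
    .agg(pl.col("v").sum().alias("x"))  # aggregation phase: columns only
    .select(pl.col("k"), (pl.col("x") * 2).alias("x"))  # post-selection phase
)
assert original.collect().sort("k").equals(rewritten.collect().sort("k"))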
@@ -372,9 +359,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.HStack, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.HStack, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) exprs = [translate_named_expr(translator, n=e) for e in node.exprs] @@ -383,7 +368,7 @@ def _( @_translate_ir.register def _( - node: pl_ir.Reduce, translator: Translator, schema: dict[str, plc.DataType] + node: pl_ir.Reduce, translator: Translator, schema: Schema ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) @@ -392,9 +377,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Distinct, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Distinct, translator: Translator, schema: Schema) -> ir.IR: (keep, subset, maintain_order, zlice) = node.options keep = ir.Distinct._KEEP_MAP[keep] subset = frozenset(subset) if subset is not None else None @@ -409,9 +392,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Sort, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Sort, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) by = [translate_named_expr(translator, n=e) for e in node.by_column] @@ -423,18 +404,14 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Slice, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Slice, translator: Translator, schema: Schema) -> ir.IR: return ir.Slice( schema, node.offset, node.len, translator.translate_ir(n=node.input) ) @_translate_ir.register -def _( - node: pl_ir.Filter, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Filter, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) mask = translate_named_expr(translator, n=node.predicate) @@ -442,18 +419,12 @@ def _( @_translate_ir.register -def _( - node: pl_ir.SimpleProjection, - translator: Translator, - schema: dict[str, plc.DataType], -) -> ir.IR: +def _(node: pl_ir.SimpleProjection, translator: Translator, schema: Schema) -> ir.IR: return ir.Projection(schema, translator.translate_ir(n=node.input)) @_translate_ir.register -def _( - node: pl_ir.MergeSorted, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.MergeSorted, translator: Translator, schema: Schema) -> ir.IR: key = node.key inp_left = translator.translate_ir(n=node.input_left) inp_right = translator.translate_ir(n=node.input_right) @@ -466,9 +437,7 @@ def _( @_translate_ir.register -def _( - node: pl_ir.MapFunction, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.MapFunction, translator: Translator, schema: Schema) -> ir.IR: name, *options = node.function return ir.MapFunction( schema, @@ -479,18 +448,14 @@ def _( @_translate_ir.register -def _( - node: pl_ir.Union, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: pl_ir.Union, translator: Translator, schema: Schema) -> ir.IR: return ir.Union( schema, node.options, *(translator.translate_ir(n=n) for n in node.inputs) ) @_translate_ir.register -def _( - node: pl_ir.HConcat, translator: Translator, schema: dict[str, plc.DataType] -) -> ir.IR: +def _(node: 
pl_ir.HConcat, translator: Translator, schema: Schema) -> ir.IR:
     return ir.HConcat(
         schema,
         False,  # noqa: FBT003
diff --git a/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py b/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py
new file mode 100644
index 00000000000..f99c5b623fe
--- /dev/null
+++ b/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py
@@ -0,0 +1,272 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utilities for rewriting aggregations."""
+
+from __future__ import annotations
+
+import itertools
+from functools import partial
+from typing import TYPE_CHECKING
+
+import pyarrow as pa
+
+import pylibcudf as plc
+
+from cudf_polars.dsl import expr, ir
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Generator, Iterable, Sequence
+
+    from cudf_polars.typing import Schema
+
+__all__ = [
+    "apply_pre_evaluation",
+    "decompose_aggs",
+    "decompose_single_agg",
+]
+
+
+def decompose_single_agg(
+    named_expr: expr.NamedExpr,
+    name_generator: Generator[str, None, None],
+    *,
+    is_top: bool,
+) -> tuple[list[expr.NamedExpr], expr.NamedExpr, bool]:
+    """
+    Decompose a single named aggregation.
+
+    Parameters
+    ----------
+    named_expr
+        The named aggregation to decompose
+    name_generator
+        Generator of unique names for temporaries introduced during decomposition.
+    is_top
+        Is this the top of an aggregation expression?
+
+    Returns
+    -------
+    aggregations
+        Expressions to apply as grouped aggregations (whose children
+        may be evaluated pointwise).
+    post_aggregate
+        Single expression to apply to post-process the grouped
+        aggregations.
+    is_nested
+        Flag indicating whether the inner expression itself
+        requires aggregations.
+
+    Raises
+    ------
+    NotImplementedError
+        If the expression contains nested aggregations or unsupported
+        operations in a grouped aggregation context.
+    """
+    agg = named_expr.value
+    name = named_expr.name
+    if isinstance(agg, expr.Col):
+        return [named_expr], named_expr, False
+    if isinstance(agg, expr.Len):
+        return [named_expr], named_expr.reconstruct(expr.Col(agg.dtype, name)), True
+    if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
+        return [], named_expr, False
+    if isinstance(agg, expr.Agg):
+        (child,) = agg.children
+        needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
+            child.dtype
+        )
+        if needs_masking and agg.options:
+            # pl.col("a").nan_max or nan_min
+            raise NotImplementedError("Nan propagation in groupby for min/max")
+        _, _, has_agg = decompose_single_agg(
+            expr.NamedExpr(next(name_generator), child), name_generator, is_top=False
+        )
+        if has_agg:
+            raise NotImplementedError("Nested aggs in groupby not supported")
+        if needs_masking:
+            child = expr.UnaryFunction(child.dtype, "mask_nans", (), child)
+            # The aggregation is just reconstructed with the new
+            # (potentially masked) child. This is safe because we recursed
+            # to ensure there are no nested aggregations.
+            return (
+                [named_expr.reconstruct(agg.reconstruct([child]))],
+                named_expr.reconstruct(expr.Col(agg.dtype, name)),
+                True,
+            )
+        elif agg.name == "sum":
+            col = (
+                expr.Cast(agg.dtype, expr.Col(plc.DataType(plc.TypeId.INT64), name))
+                if (
+                    plc.traits.is_integral(agg.dtype)
+                    and agg.dtype.id() != plc.TypeId.INT64
+                )
+                else expr.Col(agg.dtype, name)
+            )
+            if is_top:
+                # In polars sum(empty_group) => 0, but in libcudf sum(empty_group) => null
+                # So must post-process by replacing nulls, but only if we're a "top-level" agg.
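+                # For example (illustrative): in
+                #     df.group_by("k").agg(pl.col("v").sum())
+                # a group whose "v" values are all null sums to 0 in polars but
+                # to null in libcudf, hence the fill_null rewrite below.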
+ rep = expr.Literal( + agg.dtype, pa.scalar(0, type=plc.interop.to_arrow(agg.dtype)) + ) + return ( + [named_expr], + named_expr.reconstruct( + expr.UnaryFunction(agg.dtype, "fill_null", (), col, rep) + ), + True, + ) + else: + return [named_expr], expr.NamedExpr(name, col), True + else: + return [named_expr], named_expr.reconstruct(expr.Col(agg.dtype, name)), True + if isinstance(agg, expr.Ternary): + raise NotImplementedError("Ternary inside groupby") + if agg.is_pointwise: + aggs, posts, has_aggs = _decompose_aggs( + (expr.NamedExpr(next(name_generator), child) for child in agg.children), + name_generator, + is_top=False, + ) + if any(has_aggs): + # Any pointwise expression can be handled either by + # post-evaluation (if outside an aggregation). + return ( + aggs, + named_expr.reconstruct(agg.reconstruct([p.value for p in posts])), + True, + ) + else: + # Or pre-evaluation if inside an aggregation. + return ( + [named_expr], + named_expr.reconstruct(expr.Col(agg.dtype, name)), + False, + ) + raise NotImplementedError(f"No support for {type(agg)} in groupby") + + +def _decompose_aggs( + aggs: Iterable[expr.NamedExpr], + name_generator: Generator[str, None, None], + *, + is_top: bool, +) -> tuple[list[expr.NamedExpr], Sequence[expr.NamedExpr], Sequence[bool]]: + new_aggs, post, has_aggs = zip( + *(decompose_single_agg(agg, name_generator, is_top=is_top) for agg in aggs), + strict=True, + ) + return ( + list(itertools.chain.from_iterable(new_aggs)), + post, + has_aggs, + ) + + +def decompose_aggs( + aggs: Iterable[expr.NamedExpr], name_generator: Generator[str, None, None] +) -> tuple[list[expr.NamedExpr], Sequence[expr.NamedExpr]]: + """ + Process arbitrary aggregations into a form we can handle in grouped aggregations. + + Parameters + ---------- + aggs + List of aggregation expressions + name_generator + Generator of unique names for temporaries introduced during decomposition. + + Returns + ------- + aggregations + Aggregations to apply in the groupby node. + post_aggregations + Expressions to apply after aggregating (as a ``Select``). + + Notes + ----- + The aggregation expressions are guaranteed to either be + expressions that can be pointwise evaluated before the groupby + operation, or aggregations of such expressions. + + Raises + ------ + NotImplementedError + For unsupported aggregation combinations. + """ + new_aggs, post, _ = _decompose_aggs(aggs, name_generator, is_top=True) + return new_aggs, post + + +def apply_pre_evaluation( + output_schema: Schema, + inp: ir.IR, + keys: Sequence[expr.NamedExpr], + original_aggs: Sequence[expr.NamedExpr], + name_generator: Generator[str, None, None], + *extra_columns: expr.NamedExpr, +) -> tuple[ir.IR, Sequence[expr.NamedExpr], Schema, Callable[[ir.IR], ir.IR]]: + """ + Apply pre-evaluation to aggregations in a grouped or rolling context. + + Parameters + ---------- + output_schema + Schema of the plan node we're rewriting. + inp + The input to the grouped/rolling aggregation. + keys + Grouping keys (may be empty). + original_aggs + Aggregation expressions to rewrite. + name_generator + Generator of unique names for temporaries introduced during decomposition. + extra_columns + Any additional columns to be included in the output (only + relevant for rolling aggregations). Columns will appear in the + order `keys, extra_columns, original_aggs`. + + Returns + ------- + new_input + Rewritten input, suitable as input to the aggregation node + aggregations + The required aggregations. 
+    schema
+        The new schema of the aggregation node
+    post_process
+        Function to apply any post-processing to the aggregation node.
+
+    Raises
+    ------
+    NotImplementedError
+        If the aggregations are somehow unsupported.
+    """
+    aggs, post = decompose_aggs(original_aggs, name_generator)
+    assert len(post) == len(original_aggs), (
+        f"Unexpected number of post-aggs {len(post)=} {len(original_aggs)=}"
+    )
+    # Order-preserving unique
+    aggs = list(dict.fromkeys(aggs).keys())
+    if any(not isinstance(e.value, expr.Col) for e in post):
+        selection = [
+            *(key.reconstruct(expr.Col(key.value.dtype, key.name)) for key in keys),
+            *extra_columns,
+            *post,
+        ]
+        inter_schema = {
+            e.name: e.value.dtype for e in itertools.chain(keys, extra_columns, aggs)
+        }
+        return (
+            inp,
+            aggs,
+            inter_schema,
+            partial(ir.Select, output_schema, selection, True),  # noqa: FBT003
+        )
+    else:
+        return inp, aggs, output_schema, lambda inp: inp
diff --git a/python/cudf_polars/cudf_polars/dsl/utils/groupby.py b/python/cudf_polars/cudf_polars/dsl/utils/groupby.py
new file mode 100644
index 00000000000..4b4c9dac30e
--- /dev/null
+++ b/python/cudf_polars/cudf_polars/dsl/utils/groupby.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utilities for grouped aggregations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from cudf_polars.dsl import ir
+from cudf_polars.dsl.utils.aggregations import apply_pre_evaluation
+from cudf_polars.dsl.utils.naming import unique_names
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import Any
+
+    import pylibcudf as plc
+
+    from cudf_polars.dsl import expr
+    from cudf_polars.utils import config
+
+__all__ = ["rewrite_groupby"]
+
+
+def rewrite_groupby(
+    node: Any,
+    schema: dict[str, plc.DataType],
+    keys: Sequence[expr.NamedExpr],
+    aggs: Sequence[expr.NamedExpr],
+    config_options: config.ConfigOptions,
+    inp: ir.IR,
+) -> ir.IR:
+    """
+    Rewrite a groupby plan node into something we can handle.
+
+    Parameters
+    ----------
+    node
+        The polars groupby plan node.
+    schema
+        Schema of the groupby plan node.
+    keys
+        Grouping keys.
+    aggs
+        Originally requested aggregations.
+    config_options
+        Configuration options.
+    inp
+        Input plan node to the groupby.
+
+    Returns
+    -------
+    New plan node representing the grouped aggregations.
+
+    Raises
+    ------
+    NotImplementedError
+        If any of the requested aggregations are unsupported.
+
+    Notes
+    -----
+    Since libcudf can only perform grouped aggregations on columns
+    (not arbitrary expressions), the approach is to split each
+    aggregation into a pre-selection phase (evaluating expressions
+    that live within an aggregation), the aggregation phase (now
+    acting on columns only), and a post-selection phase (evaluating
+    expressions of aggregated results).
+
+    This scheme does not permit nested aggregations, so those are
+    unsupported.
+ """ + if len(aggs) == 0: + # TODO: use Distinct when the partitioned executor supports it + return ir.GroupBy( + schema, + keys, + [], + node.maintain_order, + node.options.slice, + config_options, + inp, + ) + inp, aggs, group_schema, apply_post_evaluation = apply_pre_evaluation( + schema, inp, keys, aggs, unique_names(schema.keys()) + ) + # TODO: use Distinct when the partitioned executor supports it if + # the requested aggregations are empty + inp = ir.GroupBy( + group_schema, + keys, + aggs, + node.maintain_order, + node.options.slice, + config_options, + inp, + ) + return apply_post_evaluation(inp) diff --git a/python/cudf_polars/cudf_polars/experimental/expressions.py b/python/cudf_polars/cudf_polars/experimental/expressions.py index ec0c8e3982b..2218fb39ca9 100644 --- a/python/cudf_polars/cudf_polars/experimental/expressions.py +++ b/python/cudf_polars/cudf_polars/experimental/expressions.py @@ -52,11 +52,11 @@ if TYPE_CHECKING: from collections.abc import Generator, MutableMapping, Sequence - from typing import Any, TypeAlias + from typing import TypeAlias from cudf_polars.dsl.expressions.base import Expr from cudf_polars.dsl.ir import IR - from cudf_polars.typing import GenericTransformer + from cudf_polars.typing import GenericTransformer, Schema from cudf_polars.utils.config import ConfigOptions @@ -369,7 +369,7 @@ def _decompose( # TODO: Check that we aren't concatenating misaligned # columns that cannot be broadcasted. For example, what # if one of the columns is sorted? - schema: MutableMapping[str, Any] = {} + schema: Schema = {} for ir in unique_input_irs: schema.update(ir.schema) input_ir = HConcat( diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index fe9da9abff0..a5411313e7c 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -6,23 +6,14 @@ import itertools import math -import uuid -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pylibcudf as plc -from cudf_polars.dsl.expr import ( - Agg, - BinOp, - Cast, - Col, - Len, - Literal, - NamedExpr, - UnaryFunction, -) +from cudf_polars.dsl.expr import Agg, BinOp, Col, Len, NamedExpr from cudf_polars.dsl.ir import GroupBy, Select from cudf_polars.dsl.traversal import traversal +from cudf_polars.dsl.utils.naming import unique_names from cudf_polars.experimental.base import PartitionInfo from cudf_polars.experimental.dispatch import lower_ir_node from cudf_polars.experimental.repartition import Repartition @@ -30,7 +21,7 @@ from cudf_polars.experimental.utils import _lower_ir_fallback if TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import Generator, MutableMapping from cudf_polars.dsl.expr import Expr from cudf_polars.dsl.ir import IR @@ -56,6 +47,8 @@ def combine( ------- Unified groupby-aggregation decomposition. """ + if len(decompositions) == 0: + return [], [], [] selections, aggregations, reductions = zip(*decompositions, strict=True) assert all(isinstance(ne, NamedExpr) for ne in selections) return ( @@ -66,7 +59,7 @@ def combine( def decompose( - name: str, expr: Expr + name: str, expr: Expr, *, names: Generator[str, None, None] ) -> tuple[NamedExpr, list[NamedExpr], list[NamedExpr]]: """ Decompose a groupby-aggregation expression. @@ -77,6 +70,8 @@ def decompose( Output schema name. expr The aggregation expression for a single column. + names + Generator of unique names for temporaries. 
Returns ------- @@ -88,32 +83,11 @@ def decompose( The reduction expressions. """ dtype = expr.dtype - expr = expr.children[0] if isinstance(expr, Cast) else expr - - unary_op: list[Any] = [] - if isinstance(expr, UnaryFunction) and expr.is_pointwise: - # TODO: Handle multiple/sequential unary ops - unary_op = [expr.name, expr.options] - expr = expr.children[0] - - def _wrap_unary(select: Expr) -> Expr: - # Helper function to wrap the final selection - # in a UnaryFunction (when necessary) - if unary_op: - return UnaryFunction(select.dtype, *unary_op, select) - return select if isinstance(expr, Len): - selection = NamedExpr(name, _wrap_unary(Col(dtype, name))) + selection = NamedExpr(name, Col(dtype, name)) aggregation = [NamedExpr(name, expr)] - reduction = [ - NamedExpr( - name, - # Sum reduction may require casting. - # Do it for all cases to be safe (for now) - Cast(dtype, Agg(dtype, "sum", None, Col(dtype, name))), - ) - ] + reduction = [NamedExpr(name, Agg(dtype, "sum", None, Col(dtype, name)))] return selection, aggregation, reduction if isinstance(expr, Agg): if expr.name in ("sum", "count", "min", "max"): @@ -121,31 +95,23 @@ def _wrap_unary(select: Expr) -> Expr: aggfunc = "sum" else: aggfunc = expr.name - selection = NamedExpr(name, _wrap_unary(Col(dtype, name))) + selection = NamedExpr(name, Col(dtype, name)) aggregation = [NamedExpr(name, expr)] - reduction = [ - NamedExpr( - name, - # Sum reduction may require casting. - # Do it for all cases to be safe (for now) - Cast(dtype, Agg(dtype, aggfunc, None, Col(dtype, name))), - ) - ] + reduction = [NamedExpr(name, Agg(dtype, aggfunc, None, Col(dtype, name)))] return selection, aggregation, reduction elif expr.name == "mean": (child,) = expr.children - token = str(uuid.uuid4().hex) # prevent collisions with user's names (sum, count), aggregations, reductions = combine( - decompose(f"{name}__mean_sum_{token}", Agg(dtype, "sum", None, child)), - decompose(f"{name}__mean_count_{token}", Len(dtype)), + decompose( + f"{next(names)}__mean_sum", + Agg(dtype, "sum", None, child), + names=names, + ), + decompose(f"{next(names)}__mean_count", Len(dtype), names=names), ) selection = NamedExpr( name, - _wrap_unary( - BinOp( - dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value - ) - ), + BinOp(dtype, plc.binaryop.BinaryOperator.DIV, sum.value, count.value), ) return selection, aggregations, reductions else: @@ -154,31 +120,6 @@ def _wrap_unary(select: Expr) -> Expr: f"for this aggregation type:\n{type(expr)}\n" f"Only {_GB_AGG_SUPPORTED} are supported." ) - elif isinstance(expr, BinOp): - # The expectation is that each operand of the BinOp is decomposable. - # We can then combine the decompositions of the operands to form the - # decomposition of the BinOp. 
- (left, right) = expr.children - token = str(uuid.uuid4().hex) # prevent collisions with user's names - (left_selection, right_selection), aggregations, reductions = combine( - decompose(f"{name}__left_{token}", left), - decompose(f"{name}__right_{token}", right), - ) - - selection = NamedExpr( - name, - _wrap_unary( - BinOp(dtype, expr.op, left_selection.value, right_selection.value) - ), - ) - return selection, aggregations, reductions - - elif isinstance(expr, Literal): - selection = NamedExpr(name, _wrap_unary(Col(dtype, name))) - aggregation = [] - reduction = [NamedExpr(name, expr)] - return selection, aggregation, reduction - else: # pragma: no cover # Unsupported expression raise NotImplementedError( @@ -234,10 +175,14 @@ def _( 1, ) + name_generator = unique_names(ir.schema.keys()) # Decompose the aggregation requests into three distinct phases try: selection_exprs, piecewise_exprs, reduction_exprs = combine( - *(decompose(agg.name, agg.value) for agg in ir.agg_requests) + *( + decompose(agg.name, agg.value, names=name_generator) + for agg in ir.agg_requests + ) ) except NotImplementedError: return _lower_ir_fallback( @@ -253,7 +198,7 @@ def _( ir.keys, piecewise_exprs, ir.maintain_order, - ir.options, + None, ir.config_options, child, ) @@ -300,7 +245,7 @@ def _( ir.keys, reduction_exprs, ir.maintain_order, - ir.options, + None, ir.config_options, gb_inter, ) @@ -312,20 +257,18 @@ def _( ir.keys, reduction_exprs, ir.maintain_order, - ir.options, + ir.zlice, ir.config_options, gb_inter, ) partition_info[gb_reduce] = PartitionInfo(count=post_aggregation_count) # Final Select phase - aggregated = {ne.name: ne for ne in selection_exprs} new_node = Select( ir.schema, [ - # Select the aggregated data or the original column - aggregated.get(name, NamedExpr(name, Col(dtype, name))) - for name, dtype in ir.schema.items() + *(NamedExpr(k.name, Col(k.value.dtype, k.name)) for k in ir.keys), + *selection_exprs, ], False, # noqa: FBT003 gb_reduce, diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py index 7a763c6bab8..b41b35e8259 100644 --- a/python/cudf_polars/cudf_polars/testing/plugin.py +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -129,7 +129,6 @@ def pytest_configure(config: pytest.Config) -> None: "tests/unit/lazyframe/test_lazyframe.py::test_cast_frame": "Casting that raises not supported on GPU", "tests/unit/lazyframe/test_lazyframe.py::test_lazy_cache_hit": "Debug output on stderr doesn't match", "tests/unit/operations/aggregation/test_aggregations.py::test_duration_function_literal": "Broadcasting inside groupby-agg not supported", - "tests/unit/operations/aggregation/test_aggregations.py::test_sum_empty_and_null_set": "libcudf sums column of all nulls to null, not zero", "tests/unit/operations/aggregation/test_aggregations.py::test_binary_op_agg_context_no_simplify_expr_12423": "groupby-agg of just literals should not produce collect_list", "tests/unit/operations/aggregation/test_aggregations.py::test_nan_inf_aggregation": "treatment of nans and nulls together is different in libcudf and polars in groupby-agg context", "tests/unit/operations/arithmetic/test_list_arithmetic.py::test_list_arithmetic_values[func0-func0-none]": "cudf-polars doesn't nullify division by zero", @@ -173,7 +172,6 @@ def pytest_configure(config: pytest.Config) -> None: "tests/unit/operations/test_group_by.py::test_group_by_binary_agg_with_literal": "Incorrect broadcasting of literals in groupby-agg", 
"tests/unit/operations/test_group_by.py::test_group_by_lit_series": "Incorrect broadcasting of literals in groupby-agg", "tests/unit/operations/test_group_by.py::test_aggregated_scalar_elementwise_15602": "Unsupported boolean function/dtype combination in groupby-agg", - "tests/unit/operations/test_group_by.py::test_schemas[data1-expr1-expected_select1-expected_gb1]": "Mismatching dtypes, needs cudf#15852", "tests/unit/operations/test_join.py::test_cross_join_slice_pushdown": "Need to implement slice pushdown for cross joins", "tests/unit/sql/test_cast.py::test_cast_errors[values0-values::uint8-conversion from `f64` to `u64` failed]": "Casting that raises not supported on GPU", "tests/unit/sql/test_cast.py::test_cast_errors[values1-values::uint4-conversion from `i64` to `u32` failed]": "Casting that raises not supported on GPU", @@ -184,8 +182,6 @@ def pytest_configure(config: pytest.Config) -> None: "tests/unit/test_cse.py::test_cse_predicate_self_join": "Debug output on stderr doesn't match", "tests/unit/test_empty.py::test_empty_9137": "Mismatching dtypes, needs cudf#15852", "tests/unit/test_errors.py::test_error_on_empty_group_by": "Incorrect exception raised", - # Maybe flaky, order-dependent? - "tests/unit/test_queries.py::test_group_by_agg_equals_zero_3535": "libcudf sums all nulls to null, not zero", } diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py index 86882bdc3c0..b69a7da6736 100644 --- a/python/cudf_polars/cudf_polars/typing/__init__.py +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -5,7 +5,7 @@ from __future__ import annotations -from collections.abc import Hashable, Mapping, MutableMapping +from collections.abc import Hashable, MutableMapping from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, TypedDict, Union from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir @@ -13,7 +13,7 @@ import pylibcudf as plc if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping from typing import TypeAlias import polars as pl @@ -22,6 +22,9 @@ from cudf_polars.dsl import expr, ir, nodebase __all__: list[str] = [ + "ColumnHeader", + "ColumnOptions", + "DataFrameHeader", "ExprTransformer", "GenericTransformer", "IRTransformer", @@ -29,6 +32,8 @@ "OptimizationArgs", "PolarsExpr", "PolarsIR", + "Schema", + "Slice", ] PolarsIR: TypeAlias = Union[ @@ -67,7 +72,7 @@ pl_expr.PyExprIR, ] -Schema: TypeAlias = Mapping[str, plc.DataType] +Schema: TypeAlias = dict[str, plc.DataType] Slice: TypeAlias = tuple[int, int | None] @@ -89,7 +94,7 @@ def view_current_node(self) -> PolarsIR: """Convert current plan node to python rep.""" ... - def get_schema(self) -> Mapping[str, pl.DataType]: + def get_schema(self) -> Schema: """Get the schema of the current plan node.""" ... 
diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index 53b96ba574b..e93feac854d 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -3,6 +3,7 @@ from __future__ import annotations import itertools +from datetime import date import numpy as np import pytest @@ -22,7 +23,23 @@ def df(): "key1": [1, 1, 1, 2, 3, 1, 4, 6, 7], "key2": [2, 2, 2, 2, 6, 1, 4, 6, 8], "int": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "int32": pl.Series([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=pl.Int32()), + "uint16_with_null": pl.Series( + [1, None, 2, None, None, None, 4, 5, 6], dtype=pl.UInt16() + ), "float": [7.0, 1, 2, 3, 4, 5, 6, 7, 8], + "string": ["abc", "def", "hijk", "lmno", "had", "to", "be", "or", "not"], + "datetime": [ + date(1970, 1, 1), + date(1972, 1, 10), + date(2000, 1, 1), + date(2004, 12, 1), + date(2004, 10, 1), + date(1971, 2, 1), + date(2003, 12, 1), + date(2001, 1, 1), + date(1999, 12, 31), + ], } ) @@ -45,15 +62,29 @@ def keys(request): @pytest.fixture( params=[ + [], ["int"], ["float", "int"], [pl.col("float") + pl.col("int")], + [pl.col("float").is_not_null()], + [pl.col("int32").sum()], + [pl.col("int32").mean()], + [ + pl.col("uint16_with_null").sum(), + pl.col("uint16_with_null").mean().alias("mean"), + ], [pl.col("float").max() - pl.col("int").min()], [pl.col("float").mean(), pl.col("int").std()], [(pl.col("float") - pl.lit(2)).max()], + [pl.lit(10).alias("literal_value")], [pl.col("float").sum().round(decimals=1)], [pl.col("float").round(decimals=1).sum()], [pl.col("int").first(), pl.col("float").last()], + [pl.col("int").sum(), pl.col("string").str.replace("h", "foo", literal=True)], + [ + pl.col("datetime").max(), + pl.col("datetime").max().dt.is_leap_year().alias("leapyear"), + ], ], ids=lambda aggs: "-".join(map(str, aggs)), ) @@ -112,8 +143,10 @@ def test_groupby_len(df, keys): @pytest.mark.parametrize( "expr", [ - pl.col("float").is_not_null(), (pl.col("int").max() + pl.col("float").min()).max(), + pl.when(pl.col("int") < pl.lit(2)) + .then(pl.col("float").sum()) + .otherwise(pl.lit(-2)), ], ) def test_groupby_unsupported(df, expr): diff --git a/python/cudf_polars/tests/test_rolling.py b/python/cudf_polars/tests/test_rolling.py new file mode 100644 index 00000000000..1a97d85e42d --- /dev/null +++ b/python/cudf_polars/tests/test_rolling.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from datetime import datetime
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_ir_translation_raises
+
+
+@pytest.fixture
+def df():
+    dates = pl.Series(
+        [
+            datetime.strptime("2020-01-01 13:45:48", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-01 16:42:13", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-01 16:45:09", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-02 18:12:48", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-03 19:45:32", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-08 23:16:43", "%Y-%m-%d %H:%M:%S"),
+            datetime.strptime("2020-01-10 23:16:43", "%Y-%m-%d %H:%M:%S"),
+        ],
+        dtype=pl.Datetime(time_unit="us"),
+    )
+    return pl.LazyFrame(
+        {
+            "dt": dates,
+            "values": [3, 7, 5, 9, 2, 1, 72],
+            "floats": pl.Series(
+                [float("nan"), 7, 5, 2, -10, 1, float("inf")], dtype=pl.Float64()
+            ),
+        }
+    )
+
+
+@pytest.mark.parametrize("closed", ["left", "right", "both", "none"])
+@pytest.mark.parametrize("period", ["1w4d", "48h", "180s"])
+def test_datetime_rolling(df, closed, period):
+    q = df.rolling("dt", period=period, closed=closed).agg(
+        sum_a=pl.sum("values"),
+        min_a=pl.min("values"),
+        max_a=pl.max("values"),
+    )
+
+    assert_ir_translation_raises(q, NotImplementedError)
+
+
+def test_calendrical_period_unsupported(df):
+    # "1mo" is one calendar month in polars duration syntax ("1m" would be
+    # one minute, which is not a calendrical period).
+    q = df.rolling("dt", period="1mo", closed="right").agg(sum=pl.sum("values"))
+
+    assert_ir_translation_raises(q, NotImplementedError)
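
(Editorial aside: the new modules above draw temporary column names from
`unique_names`, imported from `cudf_polars.dsl.utils.naming`, which replaces
the old uuid4-token scheme; that module's source is not shown in this part of
the diff. Below is a hypothetical sketch of such a helper, assuming only the
`Generator[str, None, None]` signature and the seeding with existing schema
names visible above; the real implementation may differ.)

# Hypothetical sketch of the unique_names helper; illustrative only.
import itertools
from collections.abc import Generator, Iterable


def unique_names(existing: Iterable[str]) -> Generator[str, None, None]:
    # Seed with the names already present so generated temporaries can never
    # collide with user-visible columns (the role previously played by the
    # uuid4 token in the old decompose()).
    taken = set(existing)
    for i in itertools.count():
        name = f"_tmp{i}"  # illustrative prefix, not necessarily the real one
        if name not in taken:
            taken.add(name)  # never hand out the same name twice
            yield name


names = unique_names(["key1", "int", "_tmp0"])
assert next(names) == "_tmp1"
assert next(names) == "_tmp2"

(Seeding the generator with the schema's names gives the same collision
guarantee as the old random token while keeping generated names
deterministic, which is friendlier to caching and debugging.)
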