Skip to content

Commit

Permalink
feat(api): support wider range of types in where arg to column redu…
Browse files Browse the repository at this point in the history
…ctions
  • Loading branch information
jcrist committed May 24, 2024
1 parent 7aba385 commit 582165f
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 47 deletions.
78 changes: 42 additions & 36 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,7 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar:
│ b │ [4, 5] │
└────────┴──────────────────────┘
"""
return ops.ArrayCollect(
self, where=self._bind_reduction_filter(where)
).to_expr()
return ops.ArrayCollect(self, where=self._bind_to_parent_table(where)).to_expr()

def identical_to(self, other: Value) -> ir.BooleanValue:
"""Return whether this expression is identical to other.
Expand Down Expand Up @@ -1153,7 +1151,7 @@ def group_concat(
'39.1: 36.7'
"""
return ops.GroupConcat(
self, sep=sep, where=self._bind_reduction_filter(where)
self, sep=sep, where=self._bind_to_parent_table(where)
).to_expr()

def __hash__(self) -> int:
Expand Down Expand Up @@ -1485,24 +1483,34 @@ def as_table(self) -> ir.Table:
"base table references to a projection"
)

def _bind_reduction_filter(self, where):
rels = self.op().relations
if isinstance(where, Deferred):
if len(rels) == 0:
raise com.IbisInputError(
"Unable to bind deferred expression to a table because "
"the expression doesn't depend on any tables"
)
elif len(rels) == 1:
(table,) = rels
return where.resolve(table.to_expr())
else:
def _bind_to_parent_table(self, value) -> Value | None:
"""Bind an expr to the parent table of `self`."""
if value is None:
return None
if isinstance(value, (Deferred, str)) or callable(value):
op = self.op()
if len(op.relations) != 1:
# TODO: I don't think this line can ever be hit by a valid
# expression, since it would require a column expression to
# directly depend on multiple tables. Currently some invalid
# expressions (like t1.a.argmin(t2.b)) aren't caught at
# construction time though, so we keep the check in for now.
raise com.RelationError(
"Cannot bind deferred expression to a table because the "
"expression depends on multiple tables"
f"Unable to bind `{value!r}` - the current expression"
f"depends on multiple tables."
)
else:
return where
table = next(iter(op.relations)).to_expr()

if isinstance(value, str):
return table[value]
elif isinstance(value, Deferred):
return value.resolve(table)
else:
value = value(table)

if not isinstance(value, Value):
return literal(value)
return value

def __deferred_repr__(self):
return f"<column[{self.type()}]>"
Expand Down Expand Up @@ -1545,7 +1553,7 @@ def approx_nunique(
55
"""
return ops.ApproxCountDistinct(
self, where=self._bind_reduction_filter(where)
self, where=self._bind_to_parent_table(where)
).to_expr()

def approx_median(
Expand Down Expand Up @@ -1585,9 +1593,7 @@ def approx_median(
>>> t.body_mass_g.approx_median(where=t.species == "Chinstrap")
3700
"""
return ops.ApproxMedian(
self, where=self._bind_reduction_filter(where)
).to_expr()
return ops.ApproxMedian(self, where=self._bind_to_parent_table(where)).to_expr()

def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the mode of a column.
Expand All @@ -1612,7 +1618,7 @@ def mode(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.mode(where=(t.species == "Gentoo") & (t.sex == "male"))
5550
"""
return ops.Mode(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Mode(self, where=self._bind_to_parent_table(where)).to_expr()

def max(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the maximum of a column.
Expand All @@ -1637,7 +1643,7 @@ def max(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.max(where=t.species == "Chinstrap")
4800
"""
return ops.Max(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Max(self, where=self._bind_to_parent_table(where)).to_expr()

def min(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the minimum of a column.
Expand All @@ -1662,7 +1668,7 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar:
>>> t.body_mass_g.min(where=t.species == "Adelie")
2850
"""
return ops.Min(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Min(self, where=self._bind_to_parent_table(where)).to_expr()

def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that maximizes `key`.
Expand Down Expand Up @@ -1690,7 +1696,7 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
'Chinstrap'
"""
return ops.ArgMax(
self, key=key, where=self._bind_reduction_filter(where)
self, key=key, where=self._bind_to_parent_table(where)
).to_expr()

def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
Expand Down Expand Up @@ -1720,7 +1726,7 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
'Adelie'
"""
return ops.ArgMin(
self, key=key, where=self._bind_reduction_filter(where)
self, key=key, where=self._bind_to_parent_table(where)
).to_expr()

def median(self, where: ir.BooleanValue | None = None) -> Scalar:
Expand Down Expand Up @@ -1776,7 +1782,7 @@ def median(self, where: ir.BooleanValue | None = None) -> Scalar:
│ Torgersen │ Adelie │
└───────────┴────────────────┘
"""
return ops.Median(self, where=where).to_expr()
return ops.Median(self, where=self._bind_to_parent_table(where)).to_expr()

def quantile(
self,
Expand Down Expand Up @@ -1846,7 +1852,7 @@ def quantile(
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, where=where).to_expr()
return op(self, quantile, where=self._bind_to_parent_table(where)).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of distinct rows in an expression.
Expand All @@ -1872,7 +1878,7 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
55
"""
return ops.CountDistinct(
self, where=self._bind_reduction_filter(where)
self, where=self._bind_to_parent_table(where)
).to_expr()

def topk(
Expand Down Expand Up @@ -1938,7 +1944,7 @@ def arbitrary(
removed_in="10.0",
instead="call `first` or `last` explicitly",
)
return ops.Arbitrary(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Arbitrary(self, where=self._bind_to_parent_table(where)).to_expr()

def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of rows in an expression.
Expand All @@ -1953,7 +1959,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
IntegerScalar
Number of elements in an expression
"""
return ops.Count(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Count(self, where=self._bind_to_parent_table(where)).to_expr()

def value_counts(self) -> ir.Table:
"""Compute a frequency table.
Expand Down Expand Up @@ -2022,7 +2028,7 @@ def first(self, where: ir.BooleanValue | None = None) -> Value:
>>> t.chars.first(where=t.chars != "a")
'b'
"""
return ops.First(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.First(self, where=self._bind_to_parent_table(where)).to_expr()

def last(self, where: ir.BooleanValue | None = None) -> Value:
"""Return the last value of a column.
Expand All @@ -2048,7 +2054,7 @@ def last(self, where: ir.BooleanValue | None = None) -> Value:
>>> t.chars.last(where=t.chars != "d")
'c'
"""
return ops.Last(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Last(self, where=self._bind_to_parent_table(where)).to_expr()

def rank(self) -> ir.IntegerColumn:
"""Compute position of first element within each equal-value group in sorted order.
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def resolve_exists_subquery(outer):
if len(parents) == 2:
return Deferred(Call(resolve_exists_subquery, _))
elif len(parents) == 1:
op = ops.Any(self, where=self._bind_reduction_filter(where))
op = ops.Any(self, where=self._bind_to_parent_table(where))
else:
raise NotImplementedError(
f'Cannot compute "any" for expression of type {type(self)} '
Expand Down Expand Up @@ -407,7 +407,7 @@ def all(self, where: BooleanValue | None = None) -> BooleanScalar:
False
"""
return ops.All(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.All(self, where=self._bind_to_parent_table(where)).to_expr()

def notall(self, where: BooleanValue | None = None) -> BooleanScalar:
"""Return whether not all elements are `True`.
Expand Down
18 changes: 9 additions & 9 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def std(
Standard deviation of `arg`
"""
return ops.StandardDev(
self, how=how, where=self._bind_reduction_filter(where)
self, how=how, where=self._bind_to_parent_table(where)
).to_expr()

def var(
Expand All @@ -810,7 +810,7 @@ def var(
Standard deviation of `arg`
"""
return ops.Variance(
self, how=how, where=self._bind_reduction_filter(where)
self, how=how, where=self._bind_to_parent_table(where)
).to_expr()

def corr(
Expand All @@ -836,7 +836,7 @@ def corr(
The correlation of `left` and `right`
"""
return ops.Correlation(
self, right, how=how, where=self._bind_reduction_filter(where)
self, right, how=how, where=self._bind_to_parent_table(where)
).to_expr()

def cov(
Expand All @@ -862,7 +862,7 @@ def cov(
The covariance of `self` and `right`
"""
return ops.Covariance(
self, right, how=how, where=self._bind_reduction_filter(where)
self, right, how=how, where=self._bind_to_parent_table(where)
).to_expr()

def mean(
Expand All @@ -883,7 +883,7 @@ def mean(
"""
# TODO(kszucs): remove the alias from the reduction method in favor
# of default name generated by ops.Value operations
return ops.Mean(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Mean(self, where=self._bind_to_parent_table(where)).to_expr()

def cummean(self, *, where=None, group_by=None, order_by=None) -> NumericColumn:
"""Return the cumulative mean of the input."""
Expand All @@ -907,7 +907,7 @@ def sum(
NumericScalar
The sum of the input expression
"""
return ops.Sum(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.Sum(self, where=self._bind_to_parent_table(where)).to_expr()

def cumsum(self, *, where=None, group_by=None, order_by=None) -> NumericColumn:
"""Return the cumulative sum of the input."""
Expand Down Expand Up @@ -1161,15 +1161,15 @@ class IntegerScalar(NumericScalar, IntegerValue):
class IntegerColumn(NumericColumn, IntegerValue):
def bit_and(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise and operator."""
return ops.BitAnd(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.BitAnd(self, where=self._bind_to_parent_table(where)).to_expr()

def bit_or(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise or operator."""
return ops.BitOr(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.BitOr(self, where=self._bind_to_parent_table(where)).to_expr()

def bit_xor(self, where: ir.BooleanValue | None = None) -> IntegerScalar:
"""Aggregate the column using the bitwise exclusive or operator."""
return ops.BitXor(self, where=self._bind_reduction_filter(where)).to_expr()
return ops.BitXor(self, where=self._bind_to_parent_table(where)).to_expr()


@public
Expand Down
75 changes: 75 additions & 0 deletions ibis/tests/expr/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import pytest

import ibis
from ibis import _


@pytest.fixture
def table():
return ibis.table(
{"ints": "int", "floats": "float", "bools": "bool", "strings": "string"}
)


@pytest.mark.parametrize(
"func",
[
pytest.param(lambda t, **kws: t.strings.arbitrary(**kws), id="arbitrary"),
pytest.param(lambda t, **kws: t.strings.collect(**kws), id="collect"),
pytest.param(lambda t, **kws: t.strings.group_concat(**kws), id="group_concat"),
pytest.param(
lambda t, **kws: t.strings.approx_nunique(**kws), id="approx_nunique"
),
pytest.param(
lambda t, **kws: t.strings.approx_median(**kws), id="approx_median"
),
pytest.param(lambda t, **kws: t.strings.mode(**kws), id="mode"),
pytest.param(lambda t, **kws: t.strings.max(**kws), id="max"),
pytest.param(lambda t, **kws: t.strings.min(**kws), id="min"),
pytest.param(lambda t, **kws: t.strings.argmax(t.ints, **kws), id="argmax"),
pytest.param(lambda t, **kws: t.strings.argmin(t.ints, **kws), id="argmin"),
pytest.param(lambda t, **kws: t.strings.median(**kws), id="median"),
pytest.param(lambda t, **kws: t.strings.quantile(0.25, **kws), id="quantile"),
pytest.param(
lambda t, **kws: t.strings.quantile([0.25, 0.75], **kws),
id="multi-quantile",
),
pytest.param(lambda t, **kws: t.strings.nunique(**kws), id="nunique"),
pytest.param(lambda t, **kws: t.strings.count(**kws), id="count"),
pytest.param(lambda t, **kws: t.strings.first(**kws), id="first"),
pytest.param(lambda t, **kws: t.strings.last(**kws), id="last"),
pytest.param(lambda t, **kws: t.ints.std(**kws), id="std"),
pytest.param(lambda t, **kws: t.ints.var(**kws), id="var"),
pytest.param(lambda t, **kws: t.ints.mean(**kws), id="mean"),
pytest.param(lambda t, **kws: t.ints.sum(**kws), id="sum"),
pytest.param(lambda t, **kws: t.ints.corr(t.floats, **kws), id="corr"),
pytest.param(lambda t, **kws: t.ints.cov(t.floats, **kws), id="cov"),
pytest.param(lambda t, **kws: t.ints.bit_and(**kws), id="bit_and"),
pytest.param(lambda t, **kws: t.ints.bit_xor(**kws), id="bit_xor"),
pytest.param(lambda t, **kws: t.ints.bit_or(**kws), id="bit_or"),
pytest.param(lambda t, **kws: t.bools.any(**kws), id="any"),
pytest.param(lambda t, **kws: t.bools.all(**kws), id="all"),
pytest.param(lambda t, **kws: t.count(**kws), id="table-count"),
pytest.param(lambda t, **kws: t.nunique(**kws), id="table-nunique"),
],
)
def test_aggregation_where(table, func):
# No where
op = func(table).op()
assert op.where is None

# Literal where
op = func(table, where=False).op()
assert op.where.equals(ibis.literal(False).op())

# Various ways to spell the same column expression
r1 = func(table, where=table.bools)
r2 = func(table, where=_.bools)
r3 = func(table, where=lambda t: t.bools)
r4 = func(table, where="bools")
assert r1.equals(r2)
assert r1.equals(r3)
assert r1.equals(r4)
assert r1.op().where.equals(table.bools.op())

0 comments on commit 582165f

Please sign in to comment.