Skip to content

Commit

Permalink
fix(api): add support for using deferreds in the argmin/argmax `k…
Browse files Browse the repository at this point in the history
…ey` argument (#9652)

Adds support for using deferreds in the `argmin` and `argmax` `key`
argument.
  • Loading branch information
cpcloud authored Jul 22, 2024
1 parent 4f39d69 commit 3f05cbc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
7 changes: 7 additions & 0 deletions ibis/expr/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,10 @@ def test_reduction_methods(fn, operation, cond):
assert node.where == resolved
else:
assert node.where == where.op()


@pytest.mark.parametrize("func_name", ["argmin", "argmax"])
def test_argminmax_deferred(func_name):
t = ibis.table({"a": "int", "b": "int"}, name="t")
func = getattr(t.a, func_name)
assert func(_.b).equals(func(t.b))
24 changes: 9 additions & 15 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,10 +1539,7 @@ def _bind_to_parent_table(self, value) -> Value | None:
def __deferred_repr__(self):
return f"<column[{self.type()}]>"

def approx_nunique(
self,
where: ir.BooleanValue | None = None,
) -> ir.IntegerScalar:
def approx_nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Return the approximate number of distinct elements in `self`.
::: {.callout-note}
Expand Down Expand Up @@ -1584,10 +1581,7 @@ def approx_nunique(
self, where=self._bind_to_parent_table(where)
).to_expr()

def approx_median(
self,
where: ir.BooleanValue | None = None,
) -> Scalar:
def approx_median(self, where: ir.BooleanValue | None = None) -> Scalar:
"""Return an approximate of the median of `self`.
::: {.callout-note}
Expand Down Expand Up @@ -1744,7 +1738,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
└─────────────┘
"""
return ops.ArgMax(
self, key=key, where=self._bind_to_parent_table(where)
self,
key=self._bind_to_parent_table(key),
where=self._bind_to_parent_table(where),
).to_expr()

def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
Expand Down Expand Up @@ -1778,7 +1774,9 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
└──────────┘
"""
return ops.ArgMin(
self, key=key, where=self._bind_to_parent_table(where)
self,
key=self._bind_to_parent_table(key),
where=self._bind_to_parent_table(where),
).to_expr()

def median(self, where: ir.BooleanValue | None = None) -> Scalar:
Expand Down Expand Up @@ -1941,11 +1939,7 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
self, where=self._bind_to_parent_table(where)
).to_expr()

def topk(
self,
k: int,
by: ir.Value | None = None,
) -> ir.Table:
def topk(self, k: int, by: ir.Value | None = None) -> ir.Table:
"""Return a "top k" expression.
Parameters
Expand Down

0 comments on commit 3f05cbc

Please sign in to comment.