Skip to content

Commit dd5342f

Browse files
kevinzwangvenkateshdb
authored andcommitted
feat!: revert daft.func behavior on literal arguments (Eventual-Inc#5087)
## Changes Made Reverts the changes in Eventual-Inc#4998 for v0.6 after further discussion. ## Related Issues <!-- Link to related GitHub issues, e.g., "Closes Eventual-Inc#123" --> ## Checklist - [x] Documented in API Docs (if applicable) - [x] Documented in User Guide (if applicable) - [x] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [x] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 9dc5891 commit dd5342f

File tree

5 files changed

+31
-16
lines changed

5 files changed

+31
-16
lines changed

daft/udf/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ class _DaftFuncDecorator:
4848
- **Async row-wise** (1 row in, 1 row out) - created by decorating a Python async function
4949
- **Generator** (1 row in, N rows out) - created by decorating a Python generator function
5050
51-
Decorated functions accept both their original argument types and Daft Expressions, and return an Expression for lazy evaluation.
52-
To run the original function, call `<your_function>.eval(<args>)`.
51+
Decorated functions accept both their original argument types and Daft Expressions.
52+
When any arguments are Expressions, they return a Daft Expression that can be used in DataFrame operations.
53+
When called with their original arguments, they execute immediately and the behavior is the same as if the function was not decorated.
5354
5455
Args:
5556
return_dtype: The data type that this function should return or yield. If not specified, it is derived from the function's return type hint.

daft/udf/generator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import sys
44
from collections.abc import Callable, Generator, Iterator
5-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, get_args, get_origin, get_type_hints
5+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, get_args, get_origin, get_type_hints, overload
66

77
from daft.daft import row_wise_udf
88
from daft.datatype import DataType
@@ -56,13 +56,20 @@ def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | No
5656
return_dtype = args[0]
5757
self.return_dtype = DataType._infer_type(return_dtype)
5858

59-
def eval(self, *args: P.args, **kwargs: P.kwargs) -> Iterator[T]:
60-
"""Run the decorated generator function eagerly and return an iterator."""
61-
return self._inner(*args, **kwargs)
59+
@overload
60+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Iterator[T]: ...
61+
@overload
62+
def __call__(self, *args: Expression, **kwargs: Expression) -> Expression: ...
63+
@overload
64+
def __call__(self, *args: Any, **kwargs: Any) -> Expression | Iterator[T]: ...
6265

63-
def __call__(self, *args: Any, **kwargs: Any) -> Expression:
66+
def __call__(self, *args: Any, **kwargs: Any) -> Expression | Iterator[T]:
6467
expr_args = get_expr_args(args, kwargs)
6568

69+
# evaluate the function eagerly if there are no expression arguments
70+
if len(expr_args) == 0:
71+
return self._inner(*args, **kwargs)
72+
6673
# temporary workaround before we implement actual generator UDFs: convert it into a list-type row-wise UDF + explode
6774
def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]:
6875
return list(self._inner(*args, **kwargs))

daft/udf/row_wise.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import sys
4-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, get_type_hints
4+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, get_type_hints, overload
55

66
from daft.daft import row_wise_udf
77

@@ -47,13 +47,20 @@ def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None):
4747
return_dtype = type_hints["return"]
4848
self.return_dtype = DataType._infer_type(return_dtype)
4949

50-
def eval(self, *args: P.args, **kwargs: P.kwargs) -> T:
51-
"""Run the decorated function eagerly and return the result immediately."""
52-
return self._inner(*args, **kwargs)
50+
@overload
51+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
52+
@overload
53+
def __call__(self, *args: Expression, **kwargs: Expression) -> Expression: ...
54+
@overload
55+
def __call__(self, *args: Any, **kwargs: Any) -> Expression | T: ...
5356

54-
def __call__(self, *args: Any, **kwargs: Any) -> Expression:
57+
def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
5558
expr_args = get_expr_args(args, kwargs)
5659

60+
# evaluate the function eagerly if there are no expression arguments
61+
if len(expr_args) == 0:
62+
return self._inner(*args, **kwargs)
63+
5764
return Expression._from_pyexpr(
5865
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, (args, kwargs), expr_args)
5966
)

tests/udf/test_generator_udf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def my_repeat(to_repeat: str, n: int):
4040
for _ in range(n):
4141
yield to_repeat
4242

43-
assert list(my_repeat.eval("foo", 3)) == ["foo", "foo", "foo"]
43+
assert list(my_repeat("foo", 3)) == ["foo", "foo", "foo"]
4444

4545

4646
def test_generator_udf_typing_iterator():

tests/udf/test_row_wise_udf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ def test_row_wise_udf_literal_eval():
9090
def my_stringify_and_sum(a: int, b: int) -> str:
9191
return f"{a + b}"
9292

93-
assert my_stringify_and_sum.eval(1, 2) == "3"
93+
assert my_stringify_and_sum(1, 2) == "3"
9494

9595

9696
def test_row_wise_udf_kwargs():
9797
@daft.func
9898
def my_stringify_and_sum_repeat(a: int, b: int, repeat: int = 1) -> str:
9999
return f"{a + b}" * repeat
100100

101-
assert my_stringify_and_sum_repeat.eval(1, 2) == "3"
102-
assert my_stringify_and_sum_repeat.eval(1, 2, 3) == "333"
101+
assert my_stringify_and_sum_repeat(1, 2) == "3"
102+
assert my_stringify_and_sum_repeat(1, 2, 3) == "333"
103103

104104
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
105105
default_df = df.select(my_stringify_and_sum_repeat(col("x"), col("y")))

0 commit comments

Comments
 (0)