|
2 | 2 |
|
3 | 3 | import sys
|
4 | 4 | from collections.abc import Callable, Generator, Iterator
|
5 |
| -from typing import TYPE_CHECKING, Any, Generic, TypeVar, get_args, get_origin, get_type_hints |
| 5 | +from typing import TYPE_CHECKING, Any, Generic, TypeVar, get_args, get_origin, get_type_hints, overload |
6 | 6 |
|
7 | 7 | from daft.daft import row_wise_udf
|
8 | 8 | from daft.datatype import DataType
|
@@ -56,13 +56,20 @@ def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | No
|
56 | 56 | return_dtype = args[0]
|
57 | 57 | self.return_dtype = DataType._infer_type(return_dtype)
|
58 | 58 |
|
59 |
| - def eval(self, *args: P.args, **kwargs: P.kwargs) -> Iterator[T]: |
60 |
| - """Run the decorated generator function eagerly and return an iterator.""" |
61 |
| - return self._inner(*args, **kwargs) |
| 59 | + @overload |
| 60 | + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Iterator[T]: ... |
| 61 | + @overload |
| 62 | + def __call__(self, *args: Expression, **kwargs: Expression) -> Expression: ... |
| 63 | + @overload |
| 64 | + def __call__(self, *args: Any, **kwargs: Any) -> Expression | Iterator[T]: ... |
62 | 65 |
|
63 |
| - def __call__(self, *args: Any, **kwargs: Any) -> Expression: |
| 66 | + def __call__(self, *args: Any, **kwargs: Any) -> Expression | Iterator[T]: |
64 | 67 | expr_args = get_expr_args(args, kwargs)
|
65 | 68 |
|
| 69 | + # evaluate the function eagerly if there are no expression arguments |
| 70 | + if len(expr_args) == 0: |
| 71 | + return self._inner(*args, **kwargs) |
| 72 | + |
66 | 73 | # temporary workaround before we implement actual generator UDFs: convert it into a list-type row-wise UDF + explode
|
67 | 74 | def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]:
|
68 | 75 | return list(self._inner(*args, **kwargs))
|
|
0 commit comments