Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add input batching for UDFs #2651

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ def stateless_udf(
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> PyExpr: ...
def stateful_udf(
name: str,
Expand All @@ -1196,6 +1197,7 @@ def stateful_udf(
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
batch_size: int | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down
22 changes: 19 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,12 @@ def stateless_udf(
expressions: builtins.list[Expression],
return_dtype: DataType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateless_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request)
_stateless_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, batch_size
)
)

@staticmethod
Expand All @@ -250,10 +253,17 @@ def stateful_udf(
return_dtype: DataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
batch_size: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateful_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, init_args
name,
partial,
[e._expr for e in expressions],
return_dtype._dtype,
resource_request,
init_args,
batch_size,
)
)

Expand Down Expand Up @@ -819,7 +829,13 @@ def batch_func(self_series):
name = name + "."
name = name + getattr(func, "__qualname__") # type: ignore[call-overload]

return StatelessUDF(name=name, func=batch_func, return_dtype=return_dtype, resource_request=None)(self)
return StatelessUDF(
name=name,
func=batch_func,
return_dtype=return_dtype,
resource_request=None,
batch_size=None,
Vince7778 marked this conversation as resolved.
Show resolved Hide resolved
)(self)

def is_null(self) -> Expression:
"""Checks if values in the Expression are Null (a special value indicating missing data)
Expand Down
132 changes: 95 additions & 37 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,13 @@
return hash(frozenset(self.bound_args.arguments.items()))


# Assumes there is at least one evaluated expression
def run_udf(
func: Callable, bound_args: BoundUDFArgs, evaluated_expressions: list[Series], py_return_dtype: PyDataType
func: Callable,
bound_args: BoundUDFArgs,
evaluated_expressions: list[Series],
py_return_dtype: PyDataType,
batch_size: int | None,
) -> PySeries:
"""API to call from Rust code that will call an UDF (initialized, in the case of stateful UDFs) on the inputs"""
return_dtype = DataType._from_pydatatype(py_return_dtype)
Expand All @@ -90,52 +95,87 @@
), "Computed series must map 1:1 to the expressions that were evaluated"
function_parameter_name_to_index = {name: i for i, name in enumerate(expressions)}

args = []
for name in arg_keys:
# special-case to skip `self` since that would be a redundant argument in a method call to a class-UDF
if name == "self":
continue

assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
args.append(pyvalues[name])
else:
args.append(evaluated_expressions[function_parameter_name_to_index[name]])

kwargs = {}
for name in kwarg_keys:
assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
kwargs[name] = pyvalues[name]
else:
kwargs[name] = evaluated_expressions[function_parameter_name_to_index[name]]
def get_args_for_slice(start: int, end: int):
args = []
must_slice = start > 0 or end < len(evaluated_expressions[0])
for name in arg_keys:
# special-case to skip `self` since that would be a redundant argument in a method call to a class-UDF
if name == "self":
continue

assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
args.append(pyvalues[name])
else:
# look up the evaluated Series bound to this parameter; it is sliced to the current batch below when needed
series = evaluated_expressions[function_parameter_name_to_index[name]]
if must_slice:
series = series.slice(start, end)
args.append(series)

kwargs = {}
for name in kwarg_keys:
assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
kwargs[name] = pyvalues[name]
else:
series = evaluated_expressions[function_parameter_name_to_index[name]]
if must_slice:
series = series.slice(start, end)
kwargs[name] = series

return args, kwargs

if batch_size is None:
args, kwargs = get_args_for_slice(0, len(evaluated_expressions[0]))
try:
results = [func(*args, **kwargs)]
except Exception as user_function_exception:
raise RuntimeError(
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(len(series) for series in evaluated_expressions)}"
) from user_function_exception
else:
# All inputs must have the same length so that every Series can be sliced with the same (start, end) bounds.
# Defensive check: it is unclear whether mismatched lengths can actually reach this point.
if len(set(len(s) for s in evaluated_expressions)) != 1:
raise RuntimeError(

Check warning on line 141 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L141

Added line #L141 was not covered by tests
f"User-defined function `{func}` failed: cannot run in batches when inputs are different lengths: {tuple(len(series) for series in evaluated_expressions)}"
)

try:
result = func(*args, **kwargs)
except Exception as user_function_exception:
raise RuntimeError(
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(len(series) for series in evaluated_expressions)}"
) from user_function_exception
results = []
for i in range(0, len(evaluated_expressions[0]), batch_size):
cur_batch_size = min(batch_size, len(evaluated_expressions[0]) - i)
args, kwargs = get_args_for_slice(i, i + cur_batch_size)
try:
results.append(func(*args, **kwargs))
except Exception as user_function_exception:
raise RuntimeError(

Check warning on line 152 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L151-L152

Added lines #L151 - L152 were not covered by tests
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(cur_batch_size for _ in evaluated_expressions)}"
) from user_function_exception

# HACK: Series have names and the logic for naming fields/series in a UDF is to take the first
# Expression's name. Note that this logic is tied to the `to_field` implementation of the Rust PythonUDF
# and is quite error prone! If our Series naming logic here is wrong, things will break when the UDF is run on a table.
name = evaluated_expressions[0].name()

# Post-processing of results into a Series of the appropriate dtype
if isinstance(result, Series):
return result.rename(name).cast(return_dtype)._series
elif isinstance(result, list):
if isinstance(results[0], Series):
result_series = Series.concat(results)
return result_series.rename(name).cast(return_dtype)._series
elif isinstance(results[0], list):
result_list = [x for res in results for x in res]
if return_dtype == DataType.python():
return Series.from_pylist(result, name=name, pyobj="force")._series
return Series.from_pylist(result_list, name=name, pyobj="force")._series
else:
return Series.from_pylist(result, name=name, pyobj="allow").cast(return_dtype)._series
elif _NUMPY_AVAILABLE and isinstance(result, np.ndarray):
return Series.from_numpy(result, name=name).cast(return_dtype)._series
elif _PYARROW_AVAILABLE and isinstance(result, (pa.Array, pa.ChunkedArray)):
return Series.from_arrow(result, name=name).cast(return_dtype)._series
return Series.from_pylist(result_list, name=name, pyobj="allow").cast(return_dtype)._series
elif _NUMPY_AVAILABLE and isinstance(results[0], np.ndarray):
result_np = np.concatenate(results)
return Series.from_numpy(result_np, name=name).cast(return_dtype)._series
elif _PYARROW_AVAILABLE and isinstance(results[0], (pa.Array, pa.ChunkedArray)):
result_pa = pa.concat_arrays(results)
return Series.from_arrow(result_pa, name=name).cast(return_dtype)._series
else:
raise NotImplementedError(f"Return type not supported for UDF: {type(result)}")
raise NotImplementedError(f"Return type not supported for UDF: {type(results[0])}")

Check warning on line 178 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L178

Added line #L178 was not covered by tests


# Marker that helps us differentiate whether a user provided the argument or not
Expand All @@ -145,6 +185,7 @@
@dataclasses.dataclass
class UDF:
resource_request: ResourceRequest | None
batch_size: int | None

@abstractmethod
def __call__(self, *args, **kwargs) -> Expression: ...
Expand All @@ -155,6 +196,7 @@
num_cpus: float | None = _UnsetMarker,
num_gpus: float | None = _UnsetMarker,
memory_bytes: int | None = _UnsetMarker,
batch_size: int | None = _UnsetMarker,
) -> UDF:
"""Replace the resource requests for running each instance of your stateless UDF.

Expand All @@ -180,11 +222,18 @@
the appropriate GPU to each UDF using `CUDA_VISIBLE_DEVICES`.
memory_bytes: Amount of memory to allocate each running instance of your UDF in bytes. If your UDF is experiencing out-of-memory errors,
this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution.
batch_size: Enables batching of the input into batches of at most this size. The results from all batches are concatenated in order. Pass `None` to disable batching.
"""
result = self

# Any changes to resource request
if not all((num_cpus is _UnsetMarker, num_gpus is _UnsetMarker, memory_bytes is _UnsetMarker)):
if not all(
(
num_cpus is _UnsetMarker,
num_gpus is _UnsetMarker,
memory_bytes is _UnsetMarker,
)
):
new_resource_request = ResourceRequest() if self.resource_request is None else self.resource_request
if num_cpus is not _UnsetMarker:
new_resource_request = new_resource_request.with_num_cpus(num_cpus)
Expand All @@ -194,6 +243,9 @@
new_resource_request = new_resource_request.with_memory_bytes(memory_bytes)
result = dataclasses.replace(result, resource_request=new_resource_request)

if batch_size is not _UnsetMarker:
result.batch_size = batch_size

Vince7778 marked this conversation as resolved.
Show resolved Hide resolved
return result


Expand Down Expand Up @@ -238,6 +290,7 @@
expressions=expressions,
return_dtype=self.return_dtype,
resource_request=self.resource_request,
batch_size=self.batch_size,
)

def bind_func(self, *args, **kwargs) -> inspect.BoundArguments:
Expand Down Expand Up @@ -287,6 +340,7 @@
return_dtype=self.return_dtype,
resource_request=self.resource_request,
init_args=self.init_args,
batch_size=self.batch_size,
)

def with_init_args(self, *args, **kwargs) -> StatefulUDF:
Expand Down Expand Up @@ -356,6 +410,7 @@
num_cpus: float | None = None,
num_gpus: float | None = None,
memory_bytes: int | None = None,
batch_size: int | None = None,
) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]:
"""Decorator to convert a Python function into a UDF

Expand Down Expand Up @@ -463,6 +518,7 @@
the appropriate GPU to each UDF using `CUDA_VISIBLE_DEVICES`.
memory_bytes: Amount of memory to allocate each running instance of your UDF in bytes. If your UDF is experiencing out-of-memory errors,
this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution.
batch_size: Enables batching of the input into batches of at most this size. The results from all batches are concatenated in order. Defaults to `None`, which does not split the input into batches.

Returns:
Callable[[UserProvidedPythonFunction], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions
Expand Down Expand Up @@ -491,13 +547,15 @@
cls=f,
return_dtype=return_dtype,
resource_request=resource_request,
batch_size=batch_size,
)
else:
return StatelessUDF(
name=name,
func=f,
return_dtype=return_dtype,
resource_request=resource_request,
batch_size=batch_size,
)

return _udf
10 changes: 10 additions & 0 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
pub batch_size: Option<usize>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand All @@ -51,6 +52,7 @@
pub resource_request: Option<ResourceRequest>,
#[cfg(feature = "python")]
pub init_args: Option<pyobj_serde::PyObjectWrapper>,
pub batch_size: Option<usize>,
}

#[cfg(feature = "python")]
Expand All @@ -60,6 +62,7 @@
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
Expand All @@ -68,6 +71,7 @@
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,
})),
inputs: expressions.into(),
})
Expand All @@ -79,13 +83,15 @@
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,

Check warning on line 86 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L86

Added line #L86 was not covered by tests
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,

Check warning on line 94 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L94

Added line #L94 was not covered by tests
})),
inputs: expressions.into(),
})
Expand All @@ -99,6 +105,7 @@
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
init_args: Option<pyo3::PyObject>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
Expand All @@ -108,6 +115,7 @@
return_dtype,
resource_request,
init_args: init_args.map(pyobj_serde::PyObjectWrapper),
batch_size,
})),
inputs: expressions.into(),
})
Expand All @@ -119,13 +127,15 @@
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,

Check warning on line 130 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L130

Added line #L130 was not covered by tests
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,

Check warning on line 138 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L138

Added line #L138 was not covered by tests
})),
inputs: expressions.into(),
})
Expand Down
Loading
Loading