Skip to content

Commit

Permalink
[FEAT] Add input batching for UDFs (#2651)
Browse files Browse the repository at this point in the history
Lets you specify a batch size for your UDFs, and then inputs are split
up into batches of at most that size before being passed in to the UDF.
  • Loading branch information
Vince7778 authored Aug 14, 2024
1 parent eeb4191 commit ab557b5
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 64 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ def stateless_udf(
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> PyExpr: ...
def stateful_udf(
name: str,
Expand All @@ -1196,6 +1197,7 @@ def stateful_udf(
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
batch_size: int | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down
22 changes: 19 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,12 @@ def stateless_udf(
expressions: builtins.list[Expression],
return_dtype: DataType,
resource_request: ResourceRequest | None,
batch_size: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateless_udf(name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request)
_stateless_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, batch_size
)
)

@staticmethod
Expand All @@ -250,10 +253,17 @@ def stateful_udf(
return_dtype: DataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
batch_size: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateful_udf(
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, init_args
name,
partial,
[e._expr for e in expressions],
return_dtype._dtype,
resource_request,
init_args,
batch_size,
)
)

Expand Down Expand Up @@ -823,7 +833,13 @@ def batch_func(self_series):
name = name + "."
name = name + getattr(func, "__qualname__") # type: ignore[call-overload]

return StatelessUDF(name=name, func=batch_func, return_dtype=return_dtype, resource_request=None)(self)
return StatelessUDF(
name=name,
func=batch_func,
return_dtype=return_dtype,
resource_request=None,
batch_size=None,
)(self)

def is_null(self) -> Expression:
"""Checks if values in the Expression are Null (a special value indicating missing data)
Expand Down
132 changes: 95 additions & 37 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,13 @@ def __hash__(self) -> int:
return hash(frozenset(self.bound_args.arguments.items()))


# Assumes there is at least one evaluated expression
def run_udf(
func: Callable, bound_args: BoundUDFArgs, evaluated_expressions: list[Series], py_return_dtype: PyDataType
func: Callable,
bound_args: BoundUDFArgs,
evaluated_expressions: list[Series],
py_return_dtype: PyDataType,
batch_size: int | None,
) -> PySeries:
"""API to call from Rust code that will call an UDF (initialized, in the case of stateful UDFs) on the inputs"""
return_dtype = DataType._from_pydatatype(py_return_dtype)
Expand All @@ -90,52 +95,87 @@ def run_udf(
), "Computed series must map 1:1 to the expressions that were evaluated"
function_parameter_name_to_index = {name: i for i, name in enumerate(expressions)}

args = []
for name in arg_keys:
# special-case to skip `self` since that would be a redundant argument in a method call to a class-UDF
if name == "self":
continue

assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
args.append(pyvalues[name])
else:
args.append(evaluated_expressions[function_parameter_name_to_index[name]])

kwargs = {}
for name in kwarg_keys:
assert name in pyvalues or name in function_parameter_name_to_index
if name in pyvalues:
kwargs[name] = pyvalues[name]
else:
kwargs[name] = evaluated_expressions[function_parameter_name_to_index[name]]
def get_args_for_slice(start: int, end: int):
    """Build the (args, kwargs) to call the UDF with for rows [start, end).

    Plain Python values from `pyvalues` are passed through unchanged for every
    batch; Series arguments are looked up in `evaluated_expressions` and sliced
    down to the requested window. Slicing is skipped entirely when the window
    covers the whole input, avoiding an unnecessary copy in the unbatched case.
    """
    args = []
    # Only slice when [start, end) is a strict sub-range of the full input.
    # NOTE(review): assumes all evaluated_expressions share the length of the
    # first one — the batching caller checks this; the unbatched path does not.
    must_slice = start > 0 or end < len(evaluated_expressions[0])
    for name in arg_keys:
        # special-case to skip `self` since that would be a redundant argument in a method call to a class-UDF
        if name == "self":
            continue

        assert name in pyvalues or name in function_parameter_name_to_index
        if name in pyvalues:
            # Literal Python argument: identical across batches.
            args.append(pyvalues[name])
        else:
            # Series argument: substitute the evaluated expression, sliced to
            # the current batch window when batching is in effect.
            series = evaluated_expressions[function_parameter_name_to_index[name]]
            if must_slice:
                series = series.slice(start, end)
            args.append(series)

    kwargs = {}
    for name in kwarg_keys:
        assert name in pyvalues or name in function_parameter_name_to_index
        if name in pyvalues:
            kwargs[name] = pyvalues[name]
        else:
            # Same Series substitution/slicing as for positional arguments.
            series = evaluated_expressions[function_parameter_name_to_index[name]]
            if must_slice:
                series = series.slice(start, end)
            kwargs[name] = series

    return args, kwargs

if batch_size is None:
args, kwargs = get_args_for_slice(0, len(evaluated_expressions[0]))
try:
results = [func(*args, **kwargs)]
except Exception as user_function_exception:
raise RuntimeError(
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(len(series) for series in evaluated_expressions)}"
) from user_function_exception
else:
# all inputs must have the same lengths for batching
# not sure this error can possibly be triggered but it's here
if len(set(len(s) for s in evaluated_expressions)) != 1:
raise RuntimeError(
f"User-defined function `{func}` failed: cannot run in batches when inputs are different lengths: {tuple(len(series) for series in evaluated_expressions)}"
)

try:
result = func(*args, **kwargs)
except Exception as user_function_exception:
raise RuntimeError(
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(len(series) for series in evaluated_expressions)}"
) from user_function_exception
results = []
for i in range(0, len(evaluated_expressions[0]), batch_size):
cur_batch_size = min(batch_size, len(evaluated_expressions[0]) - i)
args, kwargs = get_args_for_slice(i, i + cur_batch_size)
try:
results.append(func(*args, **kwargs))
except Exception as user_function_exception:
raise RuntimeError(
f"User-defined function `{func}` failed when executing on inputs with lengths: {tuple(cur_batch_size for _ in evaluated_expressions)}"
) from user_function_exception

# HACK: Series have names and the logic for naming fields/series in a UDF is to take the first
# Expression's name. Note that this logic is tied to the `to_field` implementation of the Rust PythonUDF
# and is quite error prone! If our Series naming logic here is wrong, things will break when the UDF is run on a table.
name = evaluated_expressions[0].name()

# Post-processing of results into a Series of the appropriate dtype
if isinstance(result, Series):
return result.rename(name).cast(return_dtype)._series
elif isinstance(result, list):
if isinstance(results[0], Series):
result_series = Series.concat(results)
return result_series.rename(name).cast(return_dtype)._series
elif isinstance(results[0], list):
result_list = [x for res in results for x in res]
if return_dtype == DataType.python():
return Series.from_pylist(result, name=name, pyobj="force")._series
return Series.from_pylist(result_list, name=name, pyobj="force")._series
else:
return Series.from_pylist(result, name=name, pyobj="allow").cast(return_dtype)._series
elif _NUMPY_AVAILABLE and isinstance(result, np.ndarray):
return Series.from_numpy(result, name=name).cast(return_dtype)._series
elif _PYARROW_AVAILABLE and isinstance(result, (pa.Array, pa.ChunkedArray)):
return Series.from_arrow(result, name=name).cast(return_dtype)._series
return Series.from_pylist(result_list, name=name, pyobj="allow").cast(return_dtype)._series
elif _NUMPY_AVAILABLE and isinstance(results[0], np.ndarray):
result_np = np.concatenate(results)
return Series.from_numpy(result_np, name=name).cast(return_dtype)._series
elif _PYARROW_AVAILABLE and isinstance(results[0], (pa.Array, pa.ChunkedArray)):
result_pa = pa.concat_arrays(results)
return Series.from_arrow(result_pa, name=name).cast(return_dtype)._series
else:
raise NotImplementedError(f"Return type not supported for UDF: {type(result)}")
raise NotImplementedError(f"Return type not supported for UDF: {type(results[0])}")


# Marker that helps us differentiate whether a user provided the argument or not
Expand All @@ -145,6 +185,7 @@ def run_udf(
@dataclasses.dataclass
class UDF:
resource_request: ResourceRequest | None
batch_size: int | None

@abstractmethod
def __call__(self, *args, **kwargs) -> Expression: ...
Expand All @@ -155,6 +196,7 @@ def override_options(
num_cpus: float | None = _UnsetMarker,
num_gpus: float | None = _UnsetMarker,
memory_bytes: int | None = _UnsetMarker,
batch_size: int | None = _UnsetMarker,
) -> UDF:
"""Replace the resource requests for running each instance of your stateless UDF.
Expand All @@ -180,11 +222,18 @@ def override_options(
the appropriate GPU to each UDF using `CUDA_VISIBLE_DEVICES`.
memory_bytes: Amount of memory to allocate each running instance of your UDF in bytes. If your UDF is experiencing out-of-memory errors,
this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution.
batch_size: Enables batching of the input into batches of at most this size. Results between batches are concatenated.
"""
result = self

# Any changes to resource request
if not all((num_cpus is _UnsetMarker, num_gpus is _UnsetMarker, memory_bytes is _UnsetMarker)):
if not all(
(
num_cpus is _UnsetMarker,
num_gpus is _UnsetMarker,
memory_bytes is _UnsetMarker,
)
):
new_resource_request = ResourceRequest() if self.resource_request is None else self.resource_request
if num_cpus is not _UnsetMarker:
new_resource_request = new_resource_request.with_num_cpus(num_cpus)
Expand All @@ -194,6 +243,9 @@ def override_options(
new_resource_request = new_resource_request.with_memory_bytes(memory_bytes)
result = dataclasses.replace(result, resource_request=new_resource_request)

if batch_size is not _UnsetMarker:
result.batch_size = batch_size

return result


Expand Down Expand Up @@ -238,6 +290,7 @@ def __call__(self, *args, **kwargs) -> Expression:
expressions=expressions,
return_dtype=self.return_dtype,
resource_request=self.resource_request,
batch_size=self.batch_size,
)

def bind_func(self, *args, **kwargs) -> inspect.BoundArguments:
Expand Down Expand Up @@ -287,6 +340,7 @@ def __call__(self, *args, **kwargs) -> Expression:
return_dtype=self.return_dtype,
resource_request=self.resource_request,
init_args=self.init_args,
batch_size=self.batch_size,
)

def with_init_args(self, *args, **kwargs) -> StatefulUDF:
Expand Down Expand Up @@ -356,6 +410,7 @@ def udf(
num_cpus: float | None = None,
num_gpus: float | None = None,
memory_bytes: int | None = None,
batch_size: int | None = None,
) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]:
"""Decorator to convert a Python function into a UDF
Expand Down Expand Up @@ -463,6 +518,7 @@ def udf(
the appropriate GPU to each UDF using `CUDA_VISIBLE_DEVICES`.
memory_bytes: Amount of memory to allocate each running instance of your UDF in bytes. If your UDF is experiencing out-of-memory errors,
this parameter can help hint Daft that each UDF requires a certain amount of heap memory for execution.
batch_size: Enables batching of the input into batches of at most this size. Results between batches are concatenated.
Returns:
Callable[[UserProvidedPythonFunction], UDF]: UDF decorator - converts a user-provided Python function as a UDF that can be called on Expressions
Expand Down Expand Up @@ -491,13 +547,15 @@ def _udf(f: UserProvidedPythonFunction | type) -> StatelessUDF | StatefulUDF:
cls=f,
return_dtype=return_dtype,
resource_request=resource_request,
batch_size=batch_size,
)
else:
return StatelessUDF(
name=name,
func=f,
return_dtype=return_dtype,
resource_request=resource_request,
batch_size=batch_size,
)

return _udf
10 changes: 10 additions & 0 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub struct StatelessPythonUDF {
num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
pub batch_size: Option<usize>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand All @@ -51,6 +52,7 @@ pub struct StatefulPythonUDF {
pub resource_request: Option<ResourceRequest>,
#[cfg(feature = "python")]
pub init_args: Option<pyobj_serde::PyObjectWrapper>,
pub batch_size: Option<usize>,
}

#[cfg(feature = "python")]
Expand All @@ -60,6 +62,7 @@ pub fn stateless_udf(
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
Expand All @@ -68,6 +71,7 @@ pub fn stateless_udf(
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,
})),
inputs: expressions.into(),
})
Expand All @@ -79,13 +83,15 @@ pub fn stateless_udf(
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,
})),
inputs: expressions.into(),
})
Expand All @@ -99,6 +105,7 @@ pub fn stateful_udf(
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
init_args: Option<pyo3::PyObject>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
Expand All @@ -108,6 +115,7 @@ pub fn stateful_udf(
return_dtype,
resource_request,
init_args: init_args.map(pyobj_serde::PyObjectWrapper),
batch_size,
})),
inputs: expressions.into(),
})
Expand All @@ -119,13 +127,15 @@ pub fn stateful_udf(
expressions: &[ExprRef],
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
num_expressions: expressions.len(),
return_dtype,
resource_request,
batch_size,
})),
inputs: expressions.into(),
})
Expand Down
Loading

0 comments on commit ab557b5

Please sign in to comment.