Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions daft/udf/row_wise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from daft.series import Series

if TYPE_CHECKING:
from collections.abc import Awaitable

from daft.daft import PyDataType, PySeries


Expand Down Expand Up @@ -61,6 +63,34 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
return Expression._row_wise_udf(self.name, self._inner, self.return_dtype, (args, kwargs), expr_args)


def call_async_batch_with_evaluated_exprs(
    fn: Callable[..., Awaitable[Any]],
    return_dtype: PyDataType,
    original_args: tuple[tuple[Any, ...], dict[str, Any]],
    evaluated_args_list: list[list[Any]],
) -> PySeries:
    """Run an async row-wise UDF over a batch of rows and gather the results.

    Each entry of ``evaluated_args_list`` holds the evaluated expression values
    for one row, in the order the ``Expression`` arguments appear across
    ``original_args``. Non-``Expression`` (literal) arguments are passed through
    unchanged. All per-row coroutines are awaited concurrently on a fresh event
    loop via ``asyncio.run``.

    Args:
        fn: The async UDF, invoked once per row.
        return_dtype: Datatype the gathered results are cast to.
        original_args: The ``(args, kwargs)`` the UDF was originally called
            with; ``Expression`` entries are placeholders for per-row values.
        evaluated_args_list: Per-row evaluated values for the placeholders.

    Returns:
        A ``PySeries`` of the results, one element per input row.
    """
    import asyncio

    args, kwargs = original_args

    async def run_all() -> list[Any]:
        # Coroutines are created inside the running loop and consumed via an
        # iterator. The iterator (rather than list.pop(0), which is O(n) per
        # call and mutates the caller's lists) keeps substitution O(1) per
        # argument and side-effect free.
        coroutines = []
        for evaluated_args in evaluated_args_list:
            values = iter(evaluated_args)
            new_args = [next(values) if isinstance(arg, Expression) else arg for arg in args]
            new_kwargs = {
                key: (next(values) if isinstance(arg, Expression) else arg) for key, arg in kwargs.items()
            }
            coroutines.append(fn(*new_args, **new_kwargs))
        return await asyncio.gather(*coroutines)

    outputs = asyncio.run(run_all())

    dtype = DataType._from_pydatatype(return_dtype)
    return Series.from_pylist(outputs, dtype=dtype)._series


def call_func_with_evaluated_exprs(
fn: Callable[..., Any],
return_dtype: PyDataType,
Expand Down
156 changes: 99 additions & 57 deletions src/daft-dsl/src/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,71 +128,113 @@ impl RowWisePyFn {
);
}

let is_async: bool = Python::with_gil(|py| {
py.import(pyo3::intern!(py, "asyncio"))?
.getattr(pyo3::intern!(py, "iscoroutinefunction"))?
.call1((self.inner.clone().unwrap().as_ref(),))?
.extract()
})?;
let py_return_type = PyDataType::from(self.return_dtype.clone());

let call_func_with_evaluated_exprs = pyo3::Python::with_gil(|py| {
Ok::<_, PyErr>(
py.import(pyo3::intern!(py, "daft.udf.row_wise"))?
.getattr(pyo3::intern!(py, "call_func_with_evaluated_exprs"))?
.unbind(),
)
})?;
if is_async {
Ok(pyo3::Python::with_gil(|py| {
let f = py
.import(pyo3::intern!(py, "daft.udf.row_wise"))?
.getattr(pyo3::intern!(py, "call_async_batch_with_evaluated_exprs"))?;

// To minimize gil contention, while also allowing parallelism, we chunk up the rows
// for now, it's just (num rows / (number of CPUs * 4)), clamped to the range [1, 512]
// This may need additional tuning based on usage patterns
//
// Instead of running sequentially and acquiring the gil for each row, we instead parallelize based off the chunk size.
// Each chunk then acquires the gil.
// Since we're processing data in chunks, there's less thrashing of the gil than if we were to use `.par_iter().map(|row| {Python::with_gil(..)})`
let n_cpus =
std::thread::available_parallelism().expect("Failed to get available parallelism");

let chunk_size = (num_rows / (n_cpus.get() * 4)).clamp(1, 512);

let indices: Vec<usize> = (0..num_rows).collect();
let outputs = indices
.par_chunks(chunk_size)
.map(|chunk| {
Python::with_gil(|py| {
chunk
let mut evaluted_args = Vec::with_capacity(num_rows);
for i in 0..num_rows {
let args_for_row = args
.iter()
.map(|&i| {
let args_for_row = args
.iter()
.map(|a| {
let idx = if a.len() == 1 { 0 } else { i };
LiteralValue::get_from_series(a, idx)
})
.collect::<DaftResult<Vec<_>>>()?;

let py_args = args_for_row
.into_iter()
.map(|a| a.into_pyobject(py))
.collect::<PyResult<Vec<_>>>()?;

let result = call_func_with_evaluated_exprs.bind(py).call1((
self.inner.clone().unwrap().as_ref(),
py_return_type.clone(),
self.original_args.clone().unwrap().as_ref(),
py_args,
))?;

let result_series = result.extract::<PySeries>()?.series;
Ok(result_series)
.map(|a| {
let idx = if a.len() == 1 { 0 } else { i };
LiteralValue::get_from_series(a, idx)
})
.collect::<DaftResult<Vec<Series>>>()
.collect::<DaftResult<Vec<_>>>()?;
let py_args_for_row = args_for_row
.into_iter()
.map(|a| a.into_pyobject(py))
.collect::<PyResult<Vec<_>>>()?;
evaluted_args.push(py_args_for_row);
}

let res = f.call1((
self.inner.clone().unwrap().as_ref(),
py_return_type.clone(),
self.original_args.clone().unwrap().as_ref(),
evaluted_args,
))?;
let name = args[0].name();

let result_series = res.extract::<PySeries>()?.series;

Ok::<_, PyErr>(result_series.rename(name))
})?)
} else {
let call_func_with_evaluated_exprs = pyo3::Python::with_gil(|py| {
Ok::<_, PyErr>(
py.import(pyo3::intern!(py, "daft.udf.row_wise"))?
.getattr(pyo3::intern!(py, "call_func_with_evaluated_exprs"))?
.unbind(),
)
})?;

// To minimize gil contention, while also allowing parallelism, we chunk up the rows
// for now, it's just (num rows / (number of CPUs * 4)), clamped to the range [1, 512]
// This may need additional tuning based on usage patterns
//
// Instead of running sequentially and acquiring the gil for each row, we instead parallelize based off the chunk size.
// Each chunk then acquires the gil.
// Since we're processing data in chunks, there's less thrashing of the gil than if we were to use `.par_iter().map(|row| {Python::with_gil(..)})`
let n_cpus =
std::thread::available_parallelism().expect("Failed to get available parallelism");

let chunk_size = (num_rows / (n_cpus.get() * 4)).clamp(1, 512);

let indices: Vec<usize> = (0..num_rows).collect();
let outputs = indices
.par_chunks(chunk_size)
.map(|chunk| {
Python::with_gil(|py| {
chunk
.iter()
.map(|&i| {
let args_for_row = args
.iter()
.map(|a| {
let idx = if a.len() == 1 { 0 } else { i };
LiteralValue::get_from_series(a, idx)
})
.collect::<DaftResult<Vec<_>>>()?;

let py_args = args_for_row
.into_iter()
.map(|a| a.into_pyobject(py))
.collect::<PyResult<Vec<_>>>()?;

let result = call_func_with_evaluated_exprs.bind(py).call1((
self.inner.clone().unwrap().as_ref(),
py_return_type.clone(),
self.original_args.clone().unwrap().as_ref(),
py_args,
))?;

let result_series = result.extract::<PySeries>()?.series;
Ok(result_series)
})
.collect::<DaftResult<Vec<Series>>>()
})
})
})
.collect::<DaftResult<Vec<Vec<Series>>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
.collect::<DaftResult<Vec<Vec<Series>>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();

let outputs_ref = outputs.iter().collect::<Vec<_>>();
let outputs_ref = outputs.iter().collect::<Vec<_>>();

let name = args[0].name();
let name = args[0].name();

Ok(Series::concat(&outputs_ref)?.rename(name))
Ok(Series::concat(&outputs_ref)?.rename(name))
}
}
}
13 changes: 13 additions & 0 deletions tests/udf/test_row_wise_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,16 @@ def my_stringify_and_sum_repeat(a: int, b: int, repeat: int = 1) -> str:

dynamic_repeat_df = df.select(my_stringify_and_sum_repeat(col("x"), col("y"), repeat=col("x")))
assert dynamic_repeat_df.to_pydict() == {"x": ["5", "77", "999"]}


def test_row_wise_async_udf():
    """An async row-wise UDF should produce the same values a sync UDF would."""
    import asyncio

    @daft.func
    async def async_sum_as_str(a: int, b: int) -> str:
        await asyncio.sleep(0.1)
        return f"{a + b}"

    data = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
    result = data.select(async_sum_as_str(col("x"), col("y")))
    assert result.to_pydict() == {"x": ["5", "7", "9"]}
Loading