diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index b261776edc..09c32fb536 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2038,6 +2038,33 @@ def explode(self) -> Expression: f = native.get_function_from_registry("explode") return Expression._from_pyexpr(f(self._expr)) + def list_append(self, other: Expression) -> Expression: + """Appends a value to each list in the column. + + Args: + other: A value or column of values to append to each list + + Returns: + Expression: An expression with the updated lists + + Examples: + >>> import daft + >>> df = daft.from_pydict({"a": [[1, 2], [3, 4, 5]], "b": [10, 11]}) + >>> df.with_column("combined", df["a"].list_append(df["b"])).show() + ╭─────────────┬───────┬───────────────╮ + │ a ┆ b ┆ combined │ + │ --- ┆ --- ┆ --- │ + │ List[Int64] ┆ Int64 ┆ List[Int64] │ + ╞═════════════╪═══════╪═══════════════╡ + │ [1, 2] ┆ 10 ┆ [1, 2, 10] │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ [3, 4, 5] ┆ 11 ┆ [3, 4, 5, 11] │ + ╰─────────────┴───────┴───────────────╯ + + (Showing first 2 of 2 rows) + """ + return self._eval_expressions("list_append", other) + SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace") diff --git a/src/daft-functions-list/src/append.rs b/src/daft-functions-list/src/append.rs new file mode 100644 index 0000000000..bab393537f --- /dev/null +++ b/src/daft-functions-list/src/append.rs @@ -0,0 +1,72 @@ +use common_error::{DaftResult, ensure}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + ExprRef, + functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn}, +}; +use serde::{Deserialize, Serialize}; + +use crate::series::SeriesListExtension; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListAppend; + +#[typetag::serde] +impl ScalarUDF for ListAppend { + fn name(&self) -> &'static str { + "list_append" + } + fn call(&self, inputs: daft_dsl::functions::FunctionArgs) -> DaftResult { + let input = inputs.required((0, "input"))?; + let other = inputs.required((1, "other"))?; + + // Normalize other input to same # of rows as input + let other = if other.len() == 1 { + &other.broadcast(input.len())? + } else { + other + }; + + input.list_append(other) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 input args, got {}", + inputs.len() + ); + + let input = inputs.required((0, "input"))?.to_field(schema)?; + let other = inputs.required((1, "other"))?.to_field(schema)?; + + ensure!( + input.dtype.is_list() || input.dtype.is_fixed_size_list(), + "Input must be a list" + ); + + // The other input should have the same type as the list elements + let input_exploded = input.to_exploded_field()?; + + ensure!( + input_exploded.dtype == other.dtype, + TypeError: "Cannot append value of type {} to list of type {}", + other.dtype, + input_exploded.dtype + ); + + Ok(input_exploded.to_list_field()) + } +} + +#[must_use] +pub fn list_append(expr: ExprRef, other: ExprRef) -> ExprRef { + ScalarFn::builtin(ListAppend {}, vec![expr, other]).into() +} diff --git a/src/daft-functions-list/src/kernels.rs b/src/daft-functions-list/src/kernels.rs index 7f2c6d9998..f9698ffae2 100644 --- a/src/daft-functions-list/src/kernels.rs +++ b/src/daft-functions-list/src/kernels.rs @@ -776,9 +776,9 @@ fn join_arrow_list_of_utf8s( }) } -// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with -// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()` -// times, otherwise we simply take the original array. +/// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with +/// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()` +/// times, otherwise we simply take the original array. fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box + 'a> { match arr.len() { 1 => Box::new(repeat_n(arr.get(0).unwrap(), len)), diff --git a/src/daft-functions-list/src/lib.rs b/src/daft-functions-list/src/lib.rs index b6ab17addb..b979265bd0 100644 --- a/src/daft-functions-list/src/lib.rs +++ b/src/daft-functions-list/src/lib.rs @@ -1,3 +1,4 @@ +mod append; mod bool_and; mod bool_or; mod chunk; @@ -17,6 +18,7 @@ mod sort; mod sum; mod value_counts; +pub use append::{ListAppend, list_append as append}; pub use bool_and::{ListBoolAnd, list_bool_and as bool_and}; pub use bool_or::{ListBoolOr, list_bool_or as bool_or}; pub use chunk::{ListChunk, list_chunk as chunk}; @@ -46,6 +48,7 @@ pub struct ListFunctions; impl FunctionModule for ListFunctions { fn register(parent: &mut daft_dsl::functions::FunctionRegistry) { + parent.add_fn(ListAppend); parent.add_fn(ListBoolAnd); parent.add_fn(ListBoolOr); parent.add_fn(ListChunk); diff --git a/src/daft-functions-list/src/series.rs b/src/daft-functions-list/src/series.rs index 6de3db7390..fe829913fe 100644 --- a/src/daft-functions-list/src/series.rs +++ b/src/daft-functions-list/src/series.rs @@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized { fn list_count_distinct(&self) -> DaftResult; fn list_fill(&self, num: &Int64Array) -> DaftResult; fn list_distinct(&self) -> DaftResult; + fn list_append(&self, other: &Self) -> DaftResult; } impl SeriesListExtension for Series { @@ -341,4 +342,49 @@ impl SeriesListExtension for Series { Ok(list_array.into_series()) } + + fn list_append(&self, other: &Self) -> DaftResult { + let input = if let DataType::FixedSizeList(inner_type, _) = self.data_type() { + self.cast(&DataType::List(inner_type.clone()))? + } else { + self.clone() + }; + let input = input.list()?; + + let other = other.cast(input.child_data_type())?; + let mut growable = make_growable( + self.name(), + input.child_data_type(), + vec![&input.flat_child, &other], + false, + input.flat_child.len() + other.len(), + ); + + let offsets = input.offsets(); + let mut new_lengths = Vec::with_capacity(input.len()); + for i in 0..self.len() { + if input.is_valid(i) { + let start = *offsets.get(i).unwrap(); + let end = *offsets.get(i + 1).unwrap(); + let list_size = end - start; + growable.extend(0, start as usize, list_size as usize); + new_lengths.push((list_size + 1) as usize); + } else { + new_lengths.push(1); + } + + growable.extend(1, i, 1); + } + + let child_arr = growable.build()?; + let new_offsets = arrow2::offset::Offsets::try_from_lengths(new_lengths.into_iter())?; + let list_array = ListArray::new( + input.field.clone(), + child_arr, + new_offsets.into(), + None, // All outputs are valid because of the append + ); + + Ok(list_array.into_series()) + } } diff --git a/tests/recordbatch/list/test_list_append.py b/tests/recordbatch/list/test_list_append.py new file mode 100644 index 0000000000..9229c4e74d --- /dev/null +++ b/tests/recordbatch/list/test_list_append.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import pytest + +from daft import col, lit +from daft.datatype import DataType +from daft.recordbatch import MicroPartition + + +@pytest.fixture +def table(): + return MicroPartition.from_pydict( + { + "col": [None, None, [], [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]], + "values": [None, "b", None, "c", "d", "e", "f", "g", "h"], + } + ) + + +@pytest.fixture +def fixed_table(): + table = MicroPartition.from_pydict( + { + "col": [["a", "a"], ["a", "a"], ["a", None], [None, None], None, [None, None], None, None], + "values": [None, "b", "c", "d", None, "e", "f", None], + } + ) + + fixed_dtype = DataType.fixed_size_list(DataType.string(), 2) + return table.eval_expression_list([col("col").cast(fixed_dtype), col("values")]) + + +def test_list_append_basic(table): + df = table.eval_expression_list([col("col").list_append(col("values"))]) + result = df.to_pydict() + + expected = [ + [None], + ["b"], + [None], + ["c"], + ["a", "d"], + [None, "e"], + ["a", "a", "f"], + ["a", None, "g"], + ["a", None, "a", "h"], + ] + assert result["col"] == expected + + +def test_list_append_with_literal(table): + df = table.eval_expression_list([col("col").list_append(lit("z"))]) + result = df.to_pydict() + + expected = [ + ["z"], + ["z"], + ["z"], + ["z"], + ["a", "z"], + [None, "z"], + ["a", "a", "z"], + ["a", None, "z"], + ["a", None, "a", "z"], + ] + assert result["col"] == expected + + +def test_fixed_list_append_basic(fixed_table): + df = fixed_table.eval_expression_list([col("col").list_append(col("values"))]) + result = df.to_pydict() + + expected = [ + ["a", "a", None], + ["a", "a", "b"], + ["a", None, "c"], + [None, None, "d"], + [None], + [None, None, "e"], + ["f"], + [None], + ] + assert result["col"] == expected + + +def test_fixed_list_append_literal(fixed_table): + df = fixed_table.eval_expression_list([col("col").list_append(lit("b"))]) + result = df.to_pydict() + + expected = [ + ["a", "a", "b"], + ["a", "a", "b"], + ["a", None, "b"], + [None, None, "b"], + ["b"], + [None, None, "b"], + ["b"], + ["b"], + ] + assert result["col"] == expected