Skip to content

Commit feffe55

Browse files
authored
feat: .list.append Expression (#5159)
1 parent 9d3d359 commit feffe55

File tree

6 files changed

+251
-3
lines changed

6 files changed

+251
-3
lines changed

daft/expressions/expressions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,33 @@ def explode(self) -> Expression:
20522052
f = native.get_function_from_registry("explode")
20532053
return Expression._from_pyexpr(f(self._expr))
20542054

2055+
def list_append(self, other: Expression) -> Expression:
    """Returns every list in this column with one extra element added at the end.

    Args:
        other: The value, or a column of values, to add at the end of each list

    Returns:
        Expression: An expression producing the extended lists

    Examples:
        >>> import daft
        >>> df = daft.from_pydict({"a": [[1, 2], [3, 4, 5]], "b": [10, 11]})
        >>> df.with_column("combined", df["a"].list_append(df["b"])).show()
        ╭─────────────┬───────┬───────────────╮
        │ a           ┆ b     ┆ combined      │
        │ ---         ┆ ---   ┆ ---           │
        │ List[Int64] ┆ Int64 ┆ List[Int64]   │
        ╞═════════════╪═══════╪═══════════════╡
        │ [1, 2]      ┆ 10    ┆ [1, 2, 10]    │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ [3, 4, 5]   ┆ 11    ┆ [3, 4, 5, 11] │
        ╰─────────────┴───────┴───────────────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)
    """
    # Dispatch to the function registered under the name "list_append".
    appended = self._eval_expressions("list_append", other)
    return appended
2081+
20552082

20562083
SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace")
20572084

src/daft-functions-list/src/append.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use common_error::{DaftResult, ensure};
2+
use daft_core::{
3+
prelude::{Field, Schema},
4+
series::Series,
5+
};
6+
use daft_dsl::{
7+
ExprRef,
8+
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
9+
};
10+
use serde::{Deserialize, Serialize};
11+
12+
use crate::series::SeriesListExtension;
13+
14+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
15+
pub struct ListAppend;
16+
17+
#[typetag::serde]
18+
impl ScalarUDF for ListAppend {
19+
fn name(&self) -> &'static str {
20+
"list_append"
21+
}
22+
fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {
23+
let input = inputs.required((0, "input"))?;
24+
let other = inputs.required((1, "other"))?;
25+
26+
// Normalize other input to same # of rows as input
27+
let other = if other.len() == 1 {
28+
&other.broadcast(input.len())?
29+
} else {
30+
other
31+
};
32+
33+
input.list_append(other)
34+
}
35+
36+
fn get_return_field(
37+
&self,
38+
inputs: FunctionArgs<ExprRef>,
39+
schema: &Schema,
40+
) -> DaftResult<Field> {
41+
ensure!(
42+
inputs.len() == 2,
43+
SchemaMismatch: "Expected 2 input args, got {}",
44+
inputs.len()
45+
);
46+
47+
let input = inputs.required((0, "input"))?.to_field(schema)?;
48+
let other = inputs.required((1, "other"))?.to_field(schema)?;
49+
50+
ensure!(
51+
input.dtype.is_list() || input.dtype.is_fixed_size_list(),
52+
"Input must be a list"
53+
);
54+
55+
// The other input should have the same type as the list elements
56+
let input_exploded = input.to_exploded_field()?;
57+
58+
ensure!(
59+
input_exploded.dtype == other.dtype,
60+
TypeError: "Cannot append value of type {} to list of type {}",
61+
other.dtype,
62+
input_exploded.dtype
63+
);
64+
65+
Ok(input_exploded.to_list_field())
66+
}
67+
}
68+
69+
#[must_use]
70+
pub fn list_append(expr: ExprRef, other: ExprRef) -> ExprRef {
71+
ScalarFn::builtin(ListAppend {}, vec![expr, other]).into()
72+
}

src/daft-functions-list/src/kernels.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,9 @@ fn join_arrow_list_of_utf8s(
776776
})
777777
}
778778

779-
// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with
780-
// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()`
781-
// times, otherwise we simply take the original array.
779+
/// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with
780+
/// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()`
781+
/// times, otherwise we simply take the original array.
782782
fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i64> + 'a> {
783783
match arr.len() {
784784
1 => Box::new(repeat_n(arr.get(0).unwrap(), len)),

src/daft-functions-list/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod append;
12
mod bool_and;
23
mod bool_or;
34
mod chunk;
@@ -17,6 +18,7 @@ mod sort;
1718
mod sum;
1819
mod value_counts;
1920

21+
pub use append::{ListAppend, list_append as append};
2022
pub use bool_and::{ListBoolAnd, list_bool_and as bool_and};
2123
pub use bool_or::{ListBoolOr, list_bool_or as bool_or};
2224
pub use chunk::{ListChunk, list_chunk as chunk};
@@ -46,6 +48,7 @@ pub struct ListFunctions;
4648

4749
impl FunctionModule for ListFunctions {
4850
fn register(parent: &mut daft_dsl::functions::FunctionRegistry) {
51+
parent.add_fn(ListAppend);
4952
parent.add_fn(ListBoolAnd);
5053
parent.add_fn(ListBoolOr);
5154
parent.add_fn(ListChunk);

src/daft-functions-list/src/series.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized {
2929
fn list_count_distinct(&self) -> DaftResult<Self>;
3030
fn list_fill(&self, num: &Int64Array) -> DaftResult<Self>;
3131
fn list_distinct(&self) -> DaftResult<Self>;
32+
fn list_append(&self, other: &Self) -> DaftResult<Self>;
3233
}
3334

3435
impl SeriesListExtension for Series {
@@ -341,4 +342,49 @@ impl SeriesListExtension for Series {
341342

342343
Ok(list_array.into_series())
343344
}
345+
346+
fn list_append(&self, other: &Self) -> DaftResult<Self> {
347+
let input = if let DataType::FixedSizeList(inner_type, _) = self.data_type() {
348+
self.cast(&DataType::List(inner_type.clone()))?
349+
} else {
350+
self.clone()
351+
};
352+
let input = input.list()?;
353+
354+
let other = other.cast(input.child_data_type())?;
355+
let mut growable = make_growable(
356+
self.name(),
357+
input.child_data_type(),
358+
vec![&input.flat_child, &other],
359+
false,
360+
input.flat_child.len() + other.len(),
361+
);
362+
363+
let offsets = input.offsets();
364+
let mut new_lengths = Vec::with_capacity(input.len());
365+
for i in 0..self.len() {
366+
if input.is_valid(i) {
367+
let start = *offsets.get(i).unwrap();
368+
let end = *offsets.get(i + 1).unwrap();
369+
let list_size = end - start;
370+
growable.extend(0, start as usize, list_size as usize);
371+
new_lengths.push((list_size + 1) as usize);
372+
} else {
373+
new_lengths.push(1);
374+
}
375+
376+
growable.extend(1, i, 1);
377+
}
378+
379+
let child_arr = growable.build()?;
380+
let new_offsets = arrow2::offset::Offsets::try_from_lengths(new_lengths.into_iter())?;
381+
let list_array = ListArray::new(
382+
input.field.clone(),
383+
child_arr,
384+
new_offsets.into(),
385+
None, // All outputs are valid because of the append
386+
);
387+
388+
Ok(list_array.into_series())
389+
}
344390
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from daft import col, lit
6+
from daft.datatype import DataType
7+
from daft.recordbatch import MicroPartition
8+
9+
10+
@pytest.fixture
def table():
    """MicroPartition with a string-list column ("col") and one value per row ("values").

    Covers null lists, empty lists, lists containing nulls, and null values to append.
    """
    data = {
        "col": [None, None, [], [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]],
        "values": [None, "b", None, "c", "d", "e", "f", "g", "h"],
    }
    return MicroPartition.from_pydict(data)
18+
19+
20+
@pytest.fixture
def fixed_table():
    """MicroPartition whose "col" column is cast to a fixed-size list of two strings.

    The "values" column is passed through unchanged.
    """
    pair_of_strings = DataType.fixed_size_list(DataType.string(), 2)
    raw = MicroPartition.from_pydict(
        {
            "col": [["a", "a"], ["a", "a"], ["a", None], [None, None], None, [None, None], None, None],
            "values": [None, "b", "c", "d", None, "e", "f", None],
        }
    )
    projection = [col("col").cast(pair_of_strings), col("values")]
    return raw.eval_expression_list(projection)
31+
32+
33+
def test_list_append_basic(table):
    """Appending a column of values row by row; null lists yield just the appended value."""
    appended = col("col").list_append(col("values"))
    actual = table.eval_expression_list([appended]).to_pydict()["col"]

    assert actual == [
        [None],
        ["b"],
        [None],
        ["c"],
        ["a", "d"],
        [None, "e"],
        ["a", "a", "f"],
        ["a", None, "g"],
        ["a", None, "a", "h"],
    ]
49+
50+
51+
def test_list_append_with_literal(table):
    """Appending a single literal adds the same value to every row."""
    appended = col("col").list_append(lit("z"))
    actual = table.eval_expression_list([appended]).to_pydict()["col"]

    assert actual == [
        ["z"],
        ["z"],
        ["z"],
        ["z"],
        ["a", "z"],
        [None, "z"],
        ["a", "a", "z"],
        ["a", None, "z"],
        ["a", None, "a", "z"],
    ]
67+
68+
69+
def test_fixed_list_append_basic(fixed_table):
    """Appending a column of values to fixed-size lists, including null lists."""
    appended = col("col").list_append(col("values"))
    actual = fixed_table.eval_expression_list([appended]).to_pydict()["col"]

    assert actual == [
        ["a", "a", None],
        ["a", "a", "b"],
        ["a", None, "c"],
        [None, None, "d"],
        [None],
        [None, None, "e"],
        ["f"],
        [None],
    ]
84+
85+
86+
def test_fixed_list_append_literal(fixed_table):
    """Appending a single literal to fixed-size lists adds it to every row."""
    appended = col("col").list_append(lit("b"))
    actual = fixed_table.eval_expression_list([appended]).to_pydict()["col"]

    assert actual == [
        ["a", "a", "b"],
        ["a", "a", "b"],
        ["a", None, "b"],
        [None, None, "b"],
        ["b"],
        [None, None, "b"],
        ["b"],
        ["b"],
    ]

0 commit comments

Comments
 (0)