Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4868,6 +4868,33 @@ def distinct(self) -> Expression:
"""
return self._eval_expressions("list_distinct")

def append(self, other: Expression) -> Expression:
    """Adds one element to the end of every list in the column.

    Args:
        other: The value, or column of values, to add to the end of each list

    Returns:
        Expression: An expression with the extended lists

    Examples:
        >>> import daft
        >>> df = daft.from_pydict({"a": [[1, 2], [3, 4, 5]], "b": [10, 11]})
        >>> df.with_column("combined", df["a"].list.append(df["b"])).show()
        ╭─────────────┬───────┬───────────────╮
        │ a           ┆ b     ┆ combined      │
        │ ---         ┆ ---   ┆ ---           │
        │ List[Int64] ┆ Int64 ┆ List[Int64]   │
        ╞═════════════╪═══════╪═══════════════╡
        │ [1, 2]      ┆ 10    ┆ [1, 2, 10]    │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ [3, 4, 5]   ┆ 11    ┆ [3, 4, 5, 11] │
        ╰─────────────┴───────┴───────────────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)
    """
    # Delegates to the function registered under the name "list_append".
    appended = self._eval_expressions("list_append", other)
    return appended

def unique(self) -> Expression:
"""Returns a list of distinct elements in each list, preserving order of first occurrence and ignoring nulls.

Expand Down
72 changes: 72 additions & 0 deletions src/daft-functions-list/src/append.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use common_error::{DaftResult, ensure};
use daft_core::{
prelude::{Field, Schema},
series::Series,
};
use daft_dsl::{
ExprRef,
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
};
use serde::{Deserialize, Serialize};

use crate::series::SeriesListExtension;

/// Scalar function that appends one value to the end of every list in a list column.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListAppend;

#[typetag::serde]
impl ScalarUDF for ListAppend {
    /// Name under which this function is registered.
    fn name(&self) -> &'static str {
        "list_append"
    }

    /// Appends the row-aligned value of `other` to each list of `input`.
    ///
    /// Either argument may hold a single row, in which case it is broadcast to the length of the
    /// other one before appending.
    ///
    /// # Errors
    ///
    /// Returns an error if an argument is missing, if the two arguments have different lengths
    /// and neither has exactly one row, or if broadcasting or the append itself fails.
    fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {
        let input = inputs.required((0, "input"))?;
        let other = inputs.required((1, "other"))?;

        let (input_len, other_len) = (input.len(), other.len());
        if input_len == other_len {
            return input.list_append(other);
        }

        // The append pairs rows by position, so differing lengths are only meaningful when one
        // side is a single row that can be repeated.
        ensure!(
            input_len == 1 || other_len == 1,
            ValueError: "Expected inputs to list_append to have the same length or a length of 1, got {} and {}",
            input_len,
            other_len
        );

        if other_len == 1 {
            input.list_append(&other.broadcast(input_len)?)
        } else {
            // A single-row list input must be repeated too; otherwise only the first value of
            // `other` would be appended and the remaining rows dropped.
            input.broadcast(other_len)?.list_append(other)
        }
    }

    /// Resolves the output field for `list_append(input, other)`.
    ///
    /// # Errors
    ///
    /// Returns an error unless exactly two arguments are given, `input` is a list or fixed-size
    /// list, and the dtype of `other` equals the dtype of the list's elements.
    fn get_return_field(
        &self,
        inputs: FunctionArgs<ExprRef>,
        schema: &Schema,
    ) -> DaftResult<Field> {
        ensure!(
            inputs.len() == 2,
            SchemaMismatch: "Expected 2 input args, got {}",
            inputs.len()
        );

        let input = inputs.required((0, "input"))?.to_field(schema)?;
        let other = inputs.required((1, "other"))?.to_field(schema)?;

        ensure!(
            input.dtype.is_list() || input.dtype.is_fixed_size_list(),
            "Input must be a list"
        );

        // The other input should have the same type as the list elements
        let input_exploded = input.to_exploded_field()?;

        ensure!(
            input_exploded.dtype == other.dtype,
            TypeError: "Cannot append value of type {} to list of type {}",
            other.dtype,
            input_exploded.dtype
        );

        // Wrap the element field back into a list field for the result.
        Ok(input_exploded.to_list_field())
    }
}

#[must_use]
pub fn list_append(expr: ExprRef, other: ExprRef) -> ExprRef {
ScalarFn::builtin(ListAppend {}, vec![expr, other]).into()
}
6 changes: 3 additions & 3 deletions src/daft-functions-list/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,9 @@ fn join_arrow_list_of_utf8s(
})
}

// Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with
// `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()`
// times, otherwise we simply take the original array.
/// Given an i64 array that has either 1 or `len` elements, create an iterator that yields `len`
/// elements. If the array has exactly 1 element, that element is repeated `len` times; otherwise
/// the original array is iterated as-is.
fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i64> + 'a> {
match arr.len() {
1 => Box::new(repeat_n(arr.get(0).unwrap(), len)),
Expand Down
3 changes: 3 additions & 0 deletions src/daft-functions-list/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod append;
mod bool_and;
mod bool_or;
mod chunk;
Expand All @@ -17,6 +18,7 @@ mod sort;
mod sum;
mod value_counts;

pub use append::{ListAppend, list_append as append};
pub use bool_and::{ListBoolAnd, list_bool_and as bool_and};
pub use bool_or::{ListBoolOr, list_bool_or as bool_or};
pub use chunk::{ListChunk, list_chunk as chunk};
Expand Down Expand Up @@ -46,6 +48,7 @@ pub struct ListFunctions;

impl FunctionModule for ListFunctions {
fn register(parent: &mut daft_dsl::functions::FunctionRegistry) {
parent.add_fn(ListAppend);
parent.add_fn(ListBoolAnd);
parent.add_fn(ListBoolOr);
parent.add_fn(ListChunk);
Expand Down
46 changes: 46 additions & 0 deletions src/daft-functions-list/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub trait SeriesListExtension: Sized {
fn list_count_distinct(&self) -> DaftResult<Self>;
fn list_fill(&self, num: &Int64Array) -> DaftResult<Self>;
fn list_distinct(&self) -> DaftResult<Self>;
fn list_append(&self, other: &Self) -> DaftResult<Self>;
}

impl SeriesListExtension for Series {
Expand Down Expand Up @@ -341,4 +342,49 @@ impl SeriesListExtension for Series {

Ok(list_array.into_series())
}

fn list_append(&self, other: &Self) -> DaftResult<Self> {
let input = if let DataType::FixedSizeList(inner_type, _) = self.data_type() {
self.cast(&DataType::List(inner_type.clone()))?
} else {
self.clone()
};
let input = input.list()?;

let other = other.cast(input.child_data_type())?;
let mut growable = make_growable(
self.name(),
input.child_data_type(),
vec![&input.flat_child, &other],
false,
input.flat_child.len() + other.len(),
);

let offsets = input.offsets();
let mut new_lengths = Vec::with_capacity(input.len());
for i in 0..self.len() {
if input.is_valid(i) {
let start = *offsets.get(i).unwrap();
let end = *offsets.get(i + 1).unwrap();
let list_size = end - start;
growable.extend(0, start as usize, list_size as usize);
new_lengths.push((list_size + 1) as usize);
} else {
new_lengths.push(1);
}

growable.extend(1, i, 1);
}

let child_arr = growable.build()?;
let new_offsets = arrow2::offset::Offsets::try_from_lengths(new_lengths.into_iter())?;
let list_array = ListArray::new(
input.field.clone(),
child_arr,
new_offsets.into(),
None, // All outputs are valid because of the append
);

Ok(list_array.into_series())
}
}
100 changes: 100 additions & 0 deletions tests/recordbatch/list/test_list_append.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import pytest

from daft import col, lit
from daft.datatype import DataType
from daft.recordbatch import MicroPartition


@pytest.fixture
def table():
    """String-list column (null, empty and populated lists) paired with the values to append."""
    data = {
        "col": [None, None, [], [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]],
        "values": [None, "b", None, "c", "d", "e", "f", "g", "h"],
    }
    return MicroPartition.from_pydict(data)


@pytest.fixture
def fixed_table():
    """Same idea as the list fixture, but with "col" cast to a fixed-size list of two strings."""
    data = {
        "col": [["a", "a"], ["a", "a"], ["a", None], [None, None], None, [None, None], None, None],
        "values": [None, "b", "c", "d", None, "e", "f", None],
    }
    unfixed = MicroPartition.from_pydict(data)

    two_strings = DataType.fixed_size_list(DataType.string(), 2)
    projection = [col("col").cast(two_strings), col("values")]
    return unfixed.eval_expression_list(projection)


def test_list_append_basic(table):
    """Appending a column: each row's value lands at the end of its list; null lists act as empty."""
    appended = col("col").list.append(col("values"))
    result = table.eval_expression_list([appended]).to_pydict()

    assert result["col"] == [
        [None],
        ["b"],
        [None],
        ["c"],
        ["a", "d"],
        [None, "e"],
        ["a", "a", "f"],
        ["a", None, "g"],
        ["a", None, "a", "h"],
    ]


def test_list_append_with_literal(table):
    """Appending a literal: the same value is added to every list, including null and empty ones."""
    appended = col("col").list.append(lit("z"))
    result = table.eval_expression_list([appended]).to_pydict()

    assert result["col"] == [
        ["z"],
        ["z"],
        ["z"],
        ["z"],
        ["a", "z"],
        [None, "z"],
        ["a", "a", "z"],
        ["a", None, "z"],
        ["a", None, "a", "z"],
    ]


def test_fixed_list_append_basic(fixed_table):
    """Appending a column to fixed-size lists: results grow by one; null lists act as empty."""
    appended = col("col").list.append(col("values"))
    result = fixed_table.eval_expression_list([appended]).to_pydict()

    assert result["col"] == [
        ["a", "a", None],
        ["a", "a", "b"],
        ["a", None, "c"],
        [None, None, "d"],
        [None],
        [None, None, "e"],
        ["f"],
        [None],
    ]


def test_fixed_list_append_literal(fixed_table):
    """Appending a literal to fixed-size lists: every row, null or not, gains the same value."""
    appended = col("col").list.append(lit("b"))
    result = fixed_table.eval_expression_list([appended]).to_pydict()

    assert result["col"] == [
        ["a", "a", "b"],
        ["a", "a", "b"],
        ["a", None, "b"],
        [None, None, "b"],
        ["b"],
        [None, None, "b"],
        ["b"],
        ["b"],
    ]
Loading