Skip to content

Commit

Permalink
[FEAT] agg_list support for list and struct types (#3019)
Browse files Browse the repository at this point in the history
Fix for #2981

---------

Co-authored-by: Andrew Gazelka <[email protected]>
  • Loading branch information
kevinzwang and andrewgazelka authored Oct 8, 2024
1 parent c2397bf commit 3f37a69
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 79 deletions.
7 changes: 7 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl FixedSizeListArray {
self.validity.as_ref()
}

/// Returns how many entries of this array are null.
///
/// The count is the number of unset bits in the validity bitmap; an array
/// that carries no validity bitmap reports `0`.
pub fn null_count(&self) -> usize {
    self.validity().map_or(0, |bitmap| bitmap.unset_bits())
}

pub fn concat(arrays: &[&Self]) -> DaftResult<Self> {
if arrays.is_empty() {
return Err(DaftError::ValueError(
Expand Down
136 changes: 57 additions & 79 deletions src/daft-core/src/array/ops/list_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,72 @@ use crate::{
series::IntoSeries,
};

/// Expands to the associated type and both methods of a `DaftListAggable`
/// impl, so that every array type able to provide `into_series`,
/// `make_growable`, `name`, `data_type`, `null_count` and a `field` shares
/// one implementation.
macro_rules! impl_daft_list_agg {
    () => {
        type Output = DaftResult<ListArray>;

        /// Wraps the entire array into a `ListArray` holding exactly one
        /// list that spans every element.
        fn list(&self) -> Self::Output {
            let values = self.clone().into_series();
            // A single list covering `[0, len)`.
            let bounds = vec![0, values.len() as i64];
            let offsets = arrow2::offset::OffsetsBuffer::try_from(bounds)?;
            let list_field = self.field.to_list_field()?;
            Ok(ListArray::new(list_field, values, offsets, None))
        }

        /// Builds a `ListArray` with one list per entry of `groups`, each
        /// list containing the elements found at that group's indices.
        fn grouped_list(&self, groups: &GroupIndices) -> Self::Output {
            // offsets[i]..offsets[i + 1] delimits the i-th group's elements
            // inside the flattened child array.
            let mut offsets: Vec<i64> = Vec::with_capacity(groups.len() + 1);
            let mut running_total = 0i64;
            offsets.push(running_total);
            for group in groups {
                running_total += group.len() as i64;
                offsets.push(running_total);
            }

            let mut growable: Box<dyn Growable> = Box::new(Self::make_growable(
                self.name(),
                self.data_type(),
                vec![self],
                self.null_count() > 0,
                running_total as usize,
            ));

            // Copy the elements one at a time, group after group, from the
            // only source array (index 0).
            for group in groups {
                for &row in group {
                    growable.extend(0, row as usize, 1);
                }
            }

            let list_field = self.field.to_list_field()?;
            let flat_child = growable.build()?;
            let offsets = arrow2::offset::OffsetsBuffer::try_from(offsets)?;
            Ok(ListArray::new(list_field, flat_child, offsets, None))
        }
    };
}

impl<T> DaftListAggable for DataArray<T>
where
T: DaftArrowBackedType,
Self: IntoSeries,
Self: GrowableArray,
{
type Output = DaftResult<ListArray>;
fn list(&self) -> Self::Output {
let child_series = self.clone().into_series();
let offsets = arrow2::offset::OffsetsBuffer::try_from(vec![0, child_series.len() as i64])?;
let list_field = self.field.to_list_field()?;
Ok(ListArray::new(list_field, child_series, offsets, None))
}

fn grouped_list(&self, groups: &GroupIndices) -> Self::Output {
let mut offsets = Vec::with_capacity(groups.len() + 1);

offsets.push(0);
for g in groups {
offsets.push(offsets.last().unwrap() + g.len() as i64);
}
impl_daft_list_agg!();
}

let total_capacity = *offsets.last().unwrap();
impl DaftListAggable for ListArray {
impl_daft_list_agg!();
}

let mut growable: Box<dyn Growable> = Box::new(Self::make_growable(
self.name(),
self.data_type(),
vec![self],
self.data.null_count() > 0,
total_capacity as usize,
));
impl DaftListAggable for FixedSizeListArray {
impl_daft_list_agg!();
}

for g in groups {
for idx in g {
growable.extend(0, *idx as usize, 1);
}
}
let list_field = self.field.to_list_field()?;

Ok(ListArray::new(
list_field,
growable.build()?,
arrow2::offset::OffsetsBuffer::try_from(offsets)?,
None,
))
}
impl DaftListAggable for StructArray {
impl_daft_list_agg!();
}

#[cfg(feature = "python")]
Expand Down Expand Up @@ -95,45 +115,3 @@ impl DaftListAggable for crate::datatypes::PythonArray {
Self::new(self.field().clone().into(), Box::new(arrow_array))
}
}

// Placeholder `DaftListAggable` impl for `ListArray`: both methods are stubs
// that panic through `todo!` when called.
impl DaftListAggable for ListArray {
type Output = DaftResult<Self>;

/// Not implemented: panics with the `todo!` message below.
fn list(&self) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}

/// Not implemented: `_groups` is ignored and the call panics via `todo!`.
fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}
}

// Placeholder `DaftListAggable` impl for `FixedSizeListArray`: both methods
// are stubs that panic through `todo!` when called.
impl DaftListAggable for FixedSizeListArray {
type Output = DaftResult<ListArray>;

/// Not implemented: panics with the `todo!` message below.
fn list(&self) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}

/// Not implemented: `_groups` is ignored and the call panics via `todo!`.
fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}
}

// Placeholder `DaftListAggable` impl for `StructArray`: both methods are
// stubs that panic through `todo!` when called.
impl DaftListAggable for StructArray {
type Output = DaftResult<ListArray>;

/// Not implemented: panics with the `todo!` message below.
fn list(&self) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}

/// Not implemented: `_groups` is ignored and the call panics via `todo!`.
fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output {
// TODO(FixedSizeList)
todo!("Requires new ListArrays for implementation")
}
}
7 changes: 7 additions & 0 deletions src/daft-core/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ impl StructArray {
self.validity.as_ref()
}

/// Returns how many entries of this array are null.
///
/// The count is the number of unset bits in the validity bitmap; an array
/// that carries no validity bitmap reports `0`.
pub fn null_count(&self) -> usize {
    self.validity().map_or(0, |bitmap| bitmap.unset_bits())
}

pub fn concat(arrays: &[&Self]) -> DaftResult<Self> {
if arrays.is_empty() {
return Err(DaftError::ValueError(
Expand Down
78 changes: 78 additions & 0 deletions tests/table/test_table_aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,32 @@ def test_global_pyobj_list_aggs() -> None:
assert result.to_pydict()["list"][0] == input


def test_global_list_list_aggs() -> None:
    """Global ``agg_list`` over a list column returns all rows as one list of lists."""
    rows = [[1], [2, 3, 4], [5, None], [], None]
    source = MicroPartition.from_pydict({"input": rows})
    aggregated = source.eval_expression_list([col("input").alias("list").agg_list()])

    list_column = aggregated.get_column("list")
    assert list_column.datatype() == DataType.list(DataType.list(DataType.int64()))

    # The single output row must hold every input row, nulls and empties included.
    collected = aggregated.to_pydict()["list"]
    assert collected[0] == rows


def test_global_fixed_size_list_list_aggs() -> None:
    """Global ``agg_list`` over a fixed-size-list column keeps the element dtype and values."""
    series = Series.from_pylist([[1, 2], [3, 4], [5, None], None]).cast(
        DataType.fixed_size_list(DataType.int64(), 2)
    )
    source = MicroPartition.from_pydict({"input": series})
    aggregated = source.eval_expression_list([col("input").alias("list").agg_list()])

    list_column = aggregated.get_column("list")
    assert list_column.datatype() == DataType.list(DataType.fixed_size_list(DataType.int64(), 2))

    collected = aggregated.to_pydict()["list"]
    assert collected[0] == [[1, 2], [3, 4], [5, None], None]


def test_global_struct_list_aggs() -> None:
    """Global ``agg_list`` over a struct column returns all structs as one list."""
    records = [{"a": 1, "b": 2}, {"a": 3, "b": None}, None]
    source = MicroPartition.from_pydict({"input": records})
    aggregated = source.eval_expression_list([col("input").alias("list").agg_list()])

    list_column = aggregated.get_column("list")
    assert list_column.datatype() == DataType.list(
        DataType.struct({"a": DataType.int64(), "b": DataType.int64()})
    )

    collected = aggregated.to_pydict()["list"]
    assert collected[0] == records


@pytest.mark.parametrize(
"dtype", daft_nonnull_types + daft_null_types, ids=[f"{_}" for _ in daft_nonnull_types + daft_null_types]
)
Expand Down Expand Up @@ -701,6 +727,58 @@ def test_grouped_pyobj_list_aggs() -> None:
assert result.to_pydict() == {"groups": [1, 2, None], "list": expected_groups}


def test_grouped_list_list_aggs() -> None:
    """Grouped ``agg_list`` over a list column gathers each group's rows into one list."""
    group_keys = [None, 1, None, 1, 2, 2]
    rows = [[1], [2, 3, 4], [5, None], None, [], [8, 9]]
    # Row positions per group, listed in the sorted key order asserted below (1, 2, None).
    members_by_group = [[1, 3], [4, 5], [0, 2]]

    source = MicroPartition.from_pydict({"groups": group_keys, "input": rows})
    source = source.eval_expression_list([col("groups"), col("input")])
    aggregated = source.agg(
        [col("input").alias("list").agg_list()],
        group_by=[col("groups")],
    ).sort([col("groups")])
    assert aggregated.get_column("list").datatype() == DataType.list(DataType.list(DataType.int64()))

    # Compare against the values as the table materialized them, not the raw Python input.
    materialized = source.get_column("input").to_pylist()
    expected_lists = [[materialized[pos] for pos in members] for members in members_by_group]

    assert aggregated.to_pydict() == {"groups": [1, 2, None], "list": expected_lists}


def test_grouped_fixed_size_list_list_aggs() -> None:
    """Grouped ``agg_list`` over a fixed-size-list column gathers each group's rows."""
    group_keys = [None, 1, None, 1, 2, 2]
    series = Series.from_pylist([[1, 2], [3, 4], [5, None], None, [6, 7], [8, 9]]).cast(
        DataType.fixed_size_list(DataType.int64(), 2)
    )
    # Row positions per group, listed in the sorted key order asserted below (1, 2, None).
    members_by_group = [[1, 3], [4, 5], [0, 2]]

    source = MicroPartition.from_pydict({"groups": group_keys, "input": series})
    source = source.eval_expression_list([col("groups"), col("input")])
    aggregated = source.agg(
        [col("input").alias("list").agg_list()],
        group_by=[col("groups")],
    ).sort([col("groups")])
    assert aggregated.get_column("list").datatype() == DataType.list(DataType.fixed_size_list(DataType.int64(), 2))

    # Compare against the values as the table materialized them, not the raw Python input.
    materialized = source.get_column("input").to_pylist()
    expected_lists = [[materialized[pos] for pos in members] for members in members_by_group]

    assert aggregated.to_pydict() == {"groups": [1, 2, None], "list": expected_lists}


def test_grouped_struct_list_aggs() -> None:
    """Grouped ``agg_list`` over a struct column gathers each group's structs into one list."""
    group_keys = [None, 1, None, 1, 2, 2]
    records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 5, "y": None}, None, {"x": 6, "y": 7}, {"x": 8, "y": 9}]
    # Row positions per group, listed in the sorted key order asserted below (1, 2, None).
    members_by_group = [[1, 3], [4, 5], [0, 2]]

    source = MicroPartition.from_pydict({"groups": group_keys, "input": records})
    source = source.eval_expression_list([col("groups"), col("input")])
    aggregated = source.agg(
        [col("input").alias("list").agg_list()],
        group_by=[col("groups")],
    ).sort([col("groups")])
    assert aggregated.get_column("list").datatype() == DataType.list(
        DataType.struct({"x": DataType.int64(), "y": DataType.int64()})
    )

    # Compare against the values as the table materialized them, not the raw Python input.
    materialized = source.get_column("input").to_pylist()
    expected_lists = [[materialized[pos] for pos in members] for members in members_by_group]

    assert aggregated.to_pydict() == {"groups": [1, 2, None], "list": expected_lists}


def test_list_aggs_empty() -> None:
daft_table = MicroPartition.from_pydict({"col_A": [], "col_B": []})
daft_table = daft_table.agg(
Expand Down

0 comments on commit 3f37a69

Please sign in to comment.