Skip to content

Commit cae1b42

Browse files
committed
[FEAT]: Support intersect all and except distinct/all in DataFrame
1 parent ae74c10 commit cae1b42

File tree

13 files changed

+643
-103
lines changed

13 files changed

+643
-103
lines changed

daft/daft/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,7 @@ class LogicalPlanBuilder:
16341634
) -> LogicalPlanBuilder: ...
16351635
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
16361636
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
1637+
def except_(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
16371638
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
16381639
def table_write(
16391640
self,

daft/dataframe/dataframe.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,94 @@ def intersect(self, other: "DataFrame") -> "DataFrame":
25422542
builder = self._builder.intersect(other._builder)
25432543
return DataFrame(builder)
25442544

2545+
@DataframePublicAPI
2546+
def intersect_all(self, other: "DataFrame") -> "DataFrame":
2547+
"""Returns the intersection of two DataFrames, including duplicates.
2548+
2549+
Example:
2550+
>>> import daft
2551+
>>> df1 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]})
2552+
>>> df2 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]})
2553+
>>> df1.intersect_all(df2).collect()
2554+
╭───────┬───────╮
2555+
│ a ┆ b │
2556+
│ --- ┆ --- │
2557+
│ Int64 ┆ Int64 │
2558+
╞═══════╪═══════╡
2559+
│ 1 ┆ 4 │
2560+
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
2561+
│ 2 ┆ 6 │
2562+
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
2563+
│ 2 ┆ 6 │
2564+
╰───────┴───────╯
2565+
<BLANKLINE>
2566+
(Showing first 3 of 3 rows)
2567+
2568+
Args:
2569+
other (DataFrame): DataFrame to intersect with
2570+
2571+
Returns:
2572+
DataFrame: DataFrame with the intersection of the two DataFrames, including duplicates
2573+
"""
2574+
builder = self._builder.intersect_all(other._builder)
2575+
return DataFrame(builder)
2576+
2577+
@DataframePublicAPI
2578+
def except_distinct(self, other: "DataFrame") -> "DataFrame":
2579+
"""Returns the set difference of two DataFrames.
2580+
2581+
Example:
2582+
>>> import daft
2583+
>>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
2584+
>>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
2585+
>>> df1.except_distinct(df2).collect()
2586+
╭───────┬───────╮
2587+
│ a ┆ b │
2588+
│ --- ┆ --- │
2589+
│ Int64 ┆ Int64 │
2590+
╞═══════╪═══════╡
2591+
│ 2 ┆ 5 │
2592+
╰───────┴───────╯
2593+
<BLANKLINE>
2594+
(Showing first 1 of 1 rows)
2595+
2596+
Args:
2597+
other (DataFrame): DataFrame to except with
2598+
2599+
Returns:
2600+
DataFrame: DataFrame with the set difference of the two DataFrames
2601+
"""
2602+
builder = self._builder.except_distinct(other._builder)
2603+
return DataFrame(builder)
2604+
2605+
@DataframePublicAPI
2606+
def except_all(self, other: "DataFrame") -> "DataFrame":
2607+
"""Returns the set difference of two DataFrames, considering duplicates.
2608+
2609+
Example:
2610+
>>> import daft
2611+
>>> df1 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]})
2612+
>>> df2 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]})
2613+
>>> df1.except_all(df2).collect()
2614+
╭───────┬───────╮
2615+
│ a ┆ b │
2616+
│ --- ┆ --- │
2617+
│ Int64 ┆ Int64 │
2618+
╞═══════╪═══════╡
2619+
│ 1 ┆ 4 │
2620+
╰───────┴───────╯
2621+
<BLANKLINE>
2622+
(Showing first 1 of 1 rows)
2623+
2624+
Args:
2625+
other (DataFrame): DataFrame to except with
2626+
2627+
Returns:
2628+
DataFrame: DataFrame with the set difference of the two DataFrames, considering duplicates
2629+
"""
2630+
builder = self._builder.except_all(other._builder)
2631+
return DataFrame(builder)
2632+
25452633
def _materialize_results(self) -> None:
25462634
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
25472635
context = get_context()

daft/logical/builder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,18 @@ def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
279279
builder = self._builder.intersect(other._builder, False)
280280
return LogicalPlanBuilder(builder)
281281

282+
def intersect_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
283+
builder = self._builder.intersect(other._builder, True)
284+
return LogicalPlanBuilder(builder)
285+
286+
def except_distinct(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
287+
builder = self._builder.except_(other._builder, False)
288+
return LogicalPlanBuilder(builder)
289+
290+
def except_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
291+
builder = self._builder.except_(other._builder, True)
292+
return LogicalPlanBuilder(builder)
293+
282294
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
283295
builder = self._builder.add_monotonically_increasing_id(column_name)
284296
return LogicalPlanBuilder(builder)

src/daft-core/src/array/ops/list.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{iter::repeat, sync::Arc};
22

3-
use arrow2::offset::OffsetsBuffer;
3+
use arrow2::offset::{Offsets, OffsetsBuffer};
44
use common_error::DaftResult;
55
use indexmap::{
66
map::{raw_entry_v1::RawEntryMut, RawEntryApiV1},
@@ -255,6 +255,31 @@ fn list_sort_helper_fixed_size(
255255
.collect()
256256
}
257257

258+
fn general_list_fill_helper(element: &Series, num_array: &Int64Array) -> DaftResult<Vec<Series>> {
259+
let num_iter = create_iter(num_array, element.len());
260+
let mut result = vec![];
261+
let element_data = element.as_physical()?;
262+
for (row_index, num) in num_iter.enumerate() {
263+
let list_arr = if element.is_valid(row_index) {
264+
let mut list_growable = make_growable(
265+
element.name(),
266+
element.data_type(),
267+
vec![&element_data],
268+
false,
269+
num as usize,
270+
);
271+
for _ in 0..num {
272+
list_growable.extend(0, row_index, 1);
273+
}
274+
list_growable.build()?
275+
} else {
276+
Series::full_null(element.name(), element.data_type(), num as usize)
277+
};
278+
result.push(list_arr);
279+
}
280+
Ok(result)
281+
}
282+
258283
impl ListArray {
259284
pub fn value_counts(&self) -> DaftResult<MapArray> {
260285
struct IndexRef {
@@ -625,6 +650,25 @@ impl ListArray {
625650
self.validity().cloned(),
626651
))
627652
}
653+
654+
pub fn list_fill(elem: &Series, num_array: &Int64Array) -> DaftResult<Self> {
655+
let generated = general_list_fill_helper(elem, num_array)?;
656+
let generated_refs: Vec<&Series> = generated.iter().collect();
657+
let lengths = generated.iter().map(|arr| arr.len());
658+
let offsets = Offsets::try_from_lengths(lengths)?;
659+
let flat_child = if generated_refs.is_empty() {
660+
// when there's no output, we should create an empty series
661+
Series::empty(elem.name(), elem.data_type())
662+
} else {
663+
Series::concat(&generated_refs)?
664+
};
665+
Ok(Self::new(
666+
elem.field().to_list_field()?,
667+
flat_child,
668+
offsets.into(),
669+
None,
670+
))
671+
}
628672
}
629673

630674
impl FixedSizeListArray {

src/daft-core/src/series/ops/list.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use common_error::{DaftError, DaftResult};
22
use daft_schema::field::Field;
33

44
use crate::{
5+
array::ListArray,
56
datatypes::{DataType, UInt64Array, Utf8Array},
6-
prelude::CountMode,
7+
prelude::{CountMode, Int64Array},
78
series::{IntoSeries, Series},
89
};
910

@@ -217,4 +218,14 @@ impl Series {
217218
))),
218219
}
219220
}
221+
222+
/// Given a series of data T, repeat each data T with num times to create a list, returns
223+
/// a series of repeated list.
224+
/// # Example
225+
/// ```txt
226+
/// repeat([1, 2, 3], [2, 0, 1]) --> [[1, 1], [], [3]]
227+
/// ```
228+
pub fn list_fill(&self, num: &Int64Array) -> DaftResult<Self> {
229+
ListArray::list_fill(self, num).map(|arr| arr.into_series())
230+
}
220231
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use common_error::{DaftError, DaftResult};
2+
use daft_core::{
3+
datatypes::{DataType, Field},
4+
prelude::{Schema, Series},
5+
};
6+
use daft_dsl::{
7+
functions::{ScalarFunction, ScalarUDF},
8+
ExprRef,
9+
};
10+
use serde::{Deserialize, Serialize};
11+
12+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
13+
pub struct ListFill {}
14+
15+
#[typetag::serde]
16+
impl ScalarUDF for ListFill {
17+
fn as_any(&self) -> &dyn std::any::Any {
18+
self
19+
}
20+
21+
fn name(&self) -> &'static str {
22+
"fill"
23+
}
24+
25+
fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
26+
match inputs {
27+
[n, elem] => {
28+
let num_field = n.to_field(schema)?;
29+
let elem_field = elem.to_field(schema)?;
30+
if !num_field.dtype.is_integer() {
31+
return Err(DaftError::TypeError(format!(
32+
"Expected num field to be of numeric type, received: {}",
33+
num_field.dtype
34+
)));
35+
}
36+
elem_field.to_list_field()
37+
}
38+
_ => Err(DaftError::SchemaMismatch(format!(
39+
"Expected 2 input args, got {}",
40+
inputs.len()
41+
))),
42+
}
43+
}
44+
45+
fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
46+
match inputs {
47+
[num, elem] => {
48+
let num = num.cast(&DataType::Int64)?;
49+
let num_array = num.i64()?;
50+
elem.list_fill(num_array)
51+
}
52+
_ => Err(DaftError::ValueError(format!(
53+
"Expected 2 input args, got {}",
54+
inputs.len()
55+
))),
56+
}
57+
}
58+
}
59+
60+
#[must_use]
61+
pub fn list_fill(n: ExprRef, elem: ExprRef) -> ExprRef {
62+
ScalarFunction::new(ListFill {}, vec![n, elem]).into()
63+
}

src/daft-functions/src/list/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod count;
33
mod explode;
44
mod get;
55
mod join;
6+
mod list_fill;
67
mod max;
78
mod mean;
89
mod min;
@@ -17,6 +18,7 @@ pub use count::{list_count as count, ListCount};
1718
pub use explode::{explode, Explode};
1819
pub use get::{list_get as get, ListGet};
1920
pub use join::{list_join as join, ListJoin};
21+
pub use list_fill::list_fill;
2022
pub use max::{list_max as max, ListMax};
2123
pub use mean::{list_mean as mean, ListMean};
2224
pub use min::{list_min as min, ListMin};

src/daft-logical-plan/src/builder.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,9 +482,17 @@ impl LogicalPlanBuilder {
482482
pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
483483
let logical_plan: LogicalPlan =
484484
ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?
485-
.to_optimized_join()?;
485+
.to_logical_plan()?;
486486
Ok(self.with_new_plan(logical_plan))
487487
}
488+
489+
pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
490+
let logical_plan: LogicalPlan =
491+
ops::Except::try_new(self.plan.clone(), other.plan.clone(), is_all)?
492+
.to_logical_plan()?;
493+
Ok(self.with_new_plan(logical_plan))
494+
}
495+
488496
pub fn union(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
489497
let logical_plan: LogicalPlan =
490498
ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)?
@@ -861,6 +869,11 @@ impl PyLogicalPlanBuilder {
861869
Ok(self.builder.intersect(&other.builder, is_all)?.into())
862870
}
863871

872+
#[pyo3(name = "except_")]
873+
pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
874+
Ok(self.builder.except(&other.builder, is_all)?.into())
875+
}
876+
864877
pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
865878
Ok(self
866879
.builder

src/daft-logical-plan/src/ops/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub use pivot::Pivot;
3030
pub use project::Project;
3131
pub use repartition::Repartition;
3232
pub use sample::Sample;
33-
pub use set_operations::{Intersect, Union};
33+
pub use set_operations::{Except, Intersect, Union};
3434
pub use sink::Sink;
3535
pub use sort::Sort;
3636
pub use source::Source;

0 commit comments

Comments
 (0)