Skip to content

Commit

Permalink
[FEAT] Implement Anti and Semi Join (#2379)
Browse files Browse the repository at this point in the history
closes: #2369
  • Loading branch information
samster25 authored Jun 18, 2024
1 parent e276984 commit 6a7d831
Show file tree
Hide file tree
Showing 12 changed files with 289 additions and 26 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class JoinType(Enum):
Left: int
Right: int
Outer: int
Semi: int
Anti: int

@staticmethod
def from_join_type_str(join_type: str) -> JoinType:
Expand Down
4 changes: 3 additions & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,8 +1164,10 @@ def join(

if join_strategy == JoinStrategy.SortMerge and join_type != JoinType.Inner:
raise ValueError("Sort merge join only supports inner joins")
if join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer:
elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer:
raise ValueError("Broadcast join does not support outer joins")
elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Anti:
raise ValueError("Broadcast join does not support Anti joins")

left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,))
right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,))
Expand Down
2 changes: 1 addition & 1 deletion daft/hudi/hudi_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
file=path,
file_format=file_format_config,
schema=self._schema._schema,
num_rows=record_count,
storage_config=self._storage_config,
num_rows=record_count,
size_bytes=size_bytes,
pushdowns=pushdowns,
partition_values=partition_values,
Expand Down
6 changes: 5 additions & 1 deletion src/daft-core/src/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub enum JoinType {
Left,
Right,
Outer,
Anti,
Semi,
}

#[cfg(feature = "python")]
Expand All @@ -46,7 +48,7 @@ impl JoinType {
pub fn iterator() -> std::slice::Iter<'static, JoinType> {
use JoinType::*;

static JOIN_TYPES: [JoinType; 4] = [Inner, Left, Right, Outer];
static JOIN_TYPES: [JoinType; 6] = [Inner, Left, Right, Outer, Anti, Semi];
JOIN_TYPES.iter()
}
}
Expand All @@ -62,6 +64,8 @@ impl FromStr for JoinType {
"left" => Ok(Left),
"right" => Ok(Right),
"outer" => Ok(Outer),
"anti" => Ok(Anti),
"semi" => Ok(Semi),
_ => Err(DaftError::TypeError(format!(
"Join type {} is not supported; only the following types are supported: {:?}",
join_type,
Expand Down
11 changes: 5 additions & 6 deletions src/daft-micropartition/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ impl MicroPartition {
where
F: FnOnce(&Table, &Table, &[ExprRef], &[ExprRef], JoinType) -> DaftResult<Table>,
{
let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;

let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on, how)?;
match (how, self.len(), right.len()) {
(JoinType::Inner, 0, _)
| (JoinType::Inner, _, 0)
| (JoinType::Left, 0, _)
| (JoinType::Right, _, 0)
| (JoinType::Outer, 0, 0) => {
return Ok(Self::empty(Some(join_schema.into())));
return Ok(Self::empty(Some(join_schema)));
}
_ => {}
}
Expand All @@ -58,7 +57,7 @@ impl MicroPartition {
}
};
if let TruthValue::False = tv {
return Ok(Self::empty(Some(join_schema.into())));
return Ok(Self::empty(Some(join_schema)));
}
}

Expand All @@ -67,11 +66,11 @@ impl MicroPartition {
let rt = right.concat_or_get(io_stats)?;

match (lt.as_slice(), rt.as_slice()) {
([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))),
([], _) | (_, []) => Ok(Self::empty(Some(join_schema))),
([lt], [rt]) => {
let joined_table = table_join(lt, rt, left_on, right_on, how)?;
Ok(MicroPartition::new_loaded(
join_schema.into(),
join_schema,
vec![joined_table].into(),
None,
))
Expand Down
5 changes: 3 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ impl Join {
.map(|(_, field)| field)
.cloned()
.chain(right.schema().fields.iter().filter_map(|(rname, rfield)| {
if left_join_keys.contains(rname.as_str())
&& right_join_keys.contains(rname.as_str())
if (left_join_keys.contains(rname.as_str())
&& right_join_keys.contains(rname.as_str()))
|| matches!(join_type, JoinType::Anti | JoinType::Semi)
{
right_input_mapping.insert(rname.clone(), rname.clone());
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use std::{collections::HashMap, sync::Arc};

use common_error::DaftResult;

use daft_core::schema::Schema;
use daft_core::{schema::Schema, JoinType};
use daft_dsl::{col, optimization::replace_columns_with_expressions, Expr, ExprRef};
use indexmap::IndexSet;

use crate::{
logical_ops::{Aggregate, Pivot, Project, Source},
logical_ops::{Aggregate, Join, Pivot, Project, Source},
source_info::SourceInfo,
LogicalPlan, ResourceRequest,
};
Expand Down Expand Up @@ -478,6 +478,52 @@ impl PushDownProjection {
}
}

/// Pushes a column-pruning projection below the right side of a semi/anti join.
///
/// For `JoinType::Anti` and `JoinType::Semi`, if the join requires fewer columns
/// from its right child than that child's schema provides, a `Project` selecting
/// only the required columns is inserted above the right child and the rebuilt
/// join is run through `try_optimize` again. For every other join type, or when
/// nothing can be pruned, the plan is returned as `Transformed::No`.
fn try_optimize_join(
    &self,
    join: &Join,
    plan: Arc<LogicalPlan>,
) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
    // Only semi and anti joins are rewritten by this rule.
    if !matches!(join.join_type, JoinType::Anti | JoinType::Semi) {
        return Ok(Transformed::No(plan));
    }

    // Index 1 holds the columns required from the right child.
    let all_required = plan.required_columns();
    let right_needed = all_required
        .get(1)
        .expect("we expect 2 set of required columns for join");

    // Nothing to prune when the right child exposes no more than what is needed.
    if right_needed.len() >= join.right.schema().fields.len() {
        return Ok(Transformed::No(plan));
    }

    // Project the right child down to exactly the required columns.
    let projection_exprs = right_needed
        .iter()
        .map(|name| col(name.as_str()))
        .collect::<Vec<_>>();
    let right_projection: LogicalPlan =
        Project::try_new(join.right.clone(), projection_exprs, Default::default())?.into();

    let pruned_join = plan
        .with_new_children(&[(join.left).clone(), right_projection.into()])
        .arced();

    // Give the rule another pass over the rebuilt join; if that pass changes
    // nothing, still report the rebuilt join as a transformation.
    let reoptimized = self.try_optimize(pruned_join.clone())?;
    Ok(reoptimized.or(Transformed::Yes(pruned_join)))
}

fn try_optimize_pivot(
&self,
pivot: &Pivot,
Expand Down Expand Up @@ -524,6 +570,8 @@ impl OptimizerRule for PushDownProjection {
LogicalPlan::Aggregate(aggregation) => {
self.try_optimize_aggregation(aggregation, plan.clone())
}
// Joins also do column projection
LogicalPlan::Join(join) => self.try_optimize_join(join, plan.clone()),
// Pivots also do column projection
LogicalPlan::Pivot(pivot) => self.try_optimize_pivot(pivot, plan.clone()),
_ => Ok(Transformed::No(plan)),
Expand Down
10 changes: 10 additions & 0 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,16 @@ pub(super) fn translate_single_logical_node(
"Broadcast join does not support outer joins.".to_string(),
));
}
(JoinType::Anti, _) => {
return Err(common_error::DaftError::ValueError(
"Broadcast join does not support anti joins.".to_string(),
));
}
(JoinType::Semi, _) => {
return Err(common_error::DaftError::ValueError(
"Broadcast join does not support semi joins.".to_string(),
));
}
};

if is_swapped {
Expand Down
43 changes: 43 additions & 0 deletions src/daft-table/src/ops/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,47 @@ impl Table {
}
Ok(probe_table)
}

/// Builds a probe map over this table's rows that records only row presence.
///
/// Each row is hashed with `hash_rows`. A row is inserted (keyed by an
/// `IndexHash` holding its index and hash, with a unit value) only when no
/// previously inserted row has the same hash and compares equal under the
/// comparator, so the map ends up with one entry per distinct row.
///
/// # Errors
/// Propagates errors from `hash_rows` and `build_multi_array_is_equal`.
pub fn to_probe_hash_map_without_idx(
    &self,
) -> DaftResult<HashMap<IndexHash, (), IdentityBuildHasher>> {
    const DEFAULT_SIZE: usize = 20;

    let row_hashes = self.hash_rows()?;

    // Compares two rows of this same table against each other.
    // NOTE(review): the two `true` flags presumably make nulls/NaNs compare
    // equal to themselves — confirm against `build_multi_array_is_equal`.
    let rows_equal = build_multi_array_is_equal(
        self.columns.as_slice(),
        self.columns.as_slice(),
        true,
        true,
    )?;

    let mut seen_rows: HashMap<IndexHash, (), IdentityBuildHasher> =
        HashMap::with_capacity_and_hasher(DEFAULT_SIZE, Default::default());

    // TODO(Sammy): Drop nulls using validity array if requested
    for (row, &hash) in row_hashes.as_arrow().values_iter().enumerate() {
        let slot = seen_rows.raw_entry_mut().from_hash(hash, |existing| {
            (hash == existing.hash) && rows_equal(row, existing.idx as usize)
        });
        // An occupied slot means an equal row is already recorded; skip it.
        if let RawEntryMut::Vacant(vacant) = slot {
            vacant.insert_hashed_nocheck(
                hash,
                IndexHash {
                    idx: row as u64,
                    hash,
                },
                (),
            );
        }
    }
    Ok(seen_rows)
}
}
87 changes: 83 additions & 4 deletions src/daft-table/src/ops/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow2::{bitmap::MutableBitmap, types::IndexRange};
use daft_core::{
array::ops::{arrow2::comparison::build_multi_array_is_equal, full::FullNull},
datatypes::{BooleanArray, UInt64Array},
DataType, IntoSeries,
DataType, IntoSeries, JoinType,
};
use daft_dsl::ExprRef;

Expand All @@ -21,7 +21,13 @@ pub(super) fn hash_inner_join(
left_on: &[ExprRef],
right_on: &[ExprRef],
) -> DaftResult<Table> {
let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
let join_schema = infer_join_schema(
&left.schema,
&right.schema,
left_on,
right_on,
JoinType::Inner,
)?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;

Expand Down Expand Up @@ -103,7 +109,13 @@ pub(super) fn hash_left_right_join(
right_on: &[ExprRef],
left_side: bool,
) -> DaftResult<Table> {
let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
let join_schema = infer_join_schema(
&left.schema,
&right.schema,
left_on,
right_on,
JoinType::Right,
)?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;

Expand Down Expand Up @@ -212,13 +224,80 @@ pub(super) fn hash_left_right_join(
Table::new(join_schema, join_series)
}

/// Hash-based semi / anti join returning rows of `left` only.
///
/// The join keys are evaluated on both sides and cast to matching types. A
/// presence-only probe map is built from the right keys, and each left row is
/// probed against it:
///
/// * semi join (`is_anti == false`): keeps left rows that have a match on the right;
/// * anti join (`is_anti == true`): keeps left rows that have no match on the right.
///
/// If any key column on either side has the null data type, nothing can match:
/// an anti join returns a clone of `left`, and a semi join returns `left` taken
/// with an empty index array.
///
/// # Errors
/// Propagates errors from key evaluation, type matching, hashing, comparator
/// construction and the final `take`.
pub(super) fn hash_semi_anti_join(
    left: &Table,
    right: &Table,
    left_on: &[ExprRef],
    right_on: &[ExprRef],
    is_anti: bool,
) -> DaftResult<Table> {
    let lkeys = left.eval_expression_list(left_on)?;
    let rkeys = right.eval_expression_list(right_on)?;

    let (lkeys, rkeys) = match_types_for_tables(&lkeys, &rkeys)?;

    let lidx = if lkeys.columns.iter().any(|s| s.data_type().is_null())
        || rkeys.columns.iter().any(|s| s.data_type().is_null())
    {
        if is_anti {
            // if we have a null column match, then all of the rows match for an anti join!
            return Ok(left.clone());
        } else {
            UInt64Array::empty("left_indices", &DataType::UInt64).into_series()
        }
    } else {
        let probe_table = rkeys.to_probe_hash_map_without_idx()?;

        let l_hashes = lkeys.hash_rows()?;

        // Comparator between a left key row and a right key row.
        let is_equal = build_multi_array_is_equal(
            lkeys.columns.as_slice(),
            rkeys.columns.as_slice(),
            false,
            false,
        )?;
        // NOTE(review): the capacity hint below uses the right-side row count,
        // although the indices pushed are left-side rows — confirm this is intended.
        let rows = rkeys.len();

        // The key tables are no longer needed once hashes and comparator exist.
        drop(lkeys);
        drop(rkeys);

        let mut left_idx = Vec::with_capacity(rows);
        let is_semi = !is_anti;
        for (l_idx, h) in l_hashes.as_arrow().values_iter().enumerate() {
            let is_match = probe_table
                .raw_entry()
                .from_hash(*h, |other| {
                    *h == other.hash && {
                        let r_idx = other.idx as usize;
                        is_equal(l_idx, r_idx)
                    }
                })
                .is_some();
            // Semi keeps matches, anti keeps non-matches.
            if is_match == is_semi {
                left_idx.push(l_idx as u64);
            }
        }

        UInt64Array::from(("left_indices", left_idx)).into_series()
    };

    left.take(&lidx)
}

pub(super) fn hash_outer_join(
left: &Table,
right: &Table,
left_on: &[ExprRef],
right_on: &[ExprRef],
) -> DaftResult<Table> {
let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
let join_schema = infer_join_schema(
&left.schema,
&right.schema,
left_on,
right_on,
JoinType::Outer,
)?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;

Expand Down
Loading

0 comments on commit 6a7d831

Please sign in to comment.