[FEAT] Add left, right, and outer joins #2166

Merged
14 commits merged on May 7, 2024
9 changes: 9 additions & 0 deletions CONTRIBUTING.md
@@ -47,3 +47,12 @@ To use a remote Ray cluster, run the following steps on the same operating system
2. `ln -s daft wd/daft`: create a symbolic link from the Python module to the working directory
3. `make build-release`: an optimized build to ensure that the module is small enough to be successfully uploaded to Ray. Run this after modifying any Rust code in `src/`
4. `ray job submit --working-dir wd --address "http://<head_node_host>:8265" -- python script.py`: submit `wd/script.py` to be run on Ray

### Benchmarking

Benchmark tests are located in `tests/benchmarks`. If you would like to run benchmarks, make sure to first do `make build-release` instead of `make build` in order to compile an optimized build of Daft.

1. `pytest tests/benchmarks/[test_file.py] --benchmark-only`: Run all benchmarks in a file
2. `pytest tests/benchmarks/[test_file.py] -k [test_name] --benchmark-only`: Run a specific benchmark in a file

More information about writing and using benchmarks can be found on the [pytest-benchmark docs](https://pytest-benchmark.readthedocs.io/en/latest/).
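
For reference, a minimal benchmark using the `benchmark` fixture from pytest-benchmark might look like the sketch below. The file name, data, and column names are hypothetical, and the join call assumes the `how` argument added in this PR:

```python
# tests/benchmarks/test_join_benchmark.py (hypothetical example)
import daft


def test_left_join(benchmark):
    # Two small in-memory DataFrames; real benchmarks would use larger inputs.
    left = daft.from_pydict({"id": list(range(1_000)), "a": list(range(1_000))})
    right = daft.from_pydict({"id": list(range(500)), "b": list(range(500))})

    def run():
        # Materialize the join so the benchmark measures execution, not just planning.
        return left.join(right, on="id", how="left").collect()

    benchmark(run)
```

Run it with `pytest tests/benchmarks/test_join_benchmark.py --benchmark-only` after building with `make build-release`.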
7 changes: 5 additions & 2 deletions daft/daft.pyi
Expand Up @@ -87,6 +87,7 @@ class JoinType(Enum):
Inner: int
Left: int
Right: int
Outer: int

@staticmethod
def from_join_type_str(join_type: str) -> JoinType:
@@ -1161,7 +1162,7 @@ class PyTable:
def pivot(
self, group_by: list[PyExpr], pivot_column: PyExpr, values_column: PyExpr, names: list[str]
) -> PyTable: ...
def hash_join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ...
def hash_join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr], how: JoinType) -> PyTable: ...
def sort_merge_join(
self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool
) -> PyTable: ...
@@ -1220,10 +1221,12 @@ class PyMicroPartition:
def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ...
def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ...
def hash_join(
self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], how: JoinType
) -> PyMicroPartition: ...
def pivot(
self, group_by: list[PyExpr], pivot_column: PyExpr, values_column: PyExpr, names: list[str]
) -> PyMicroPartition: ...
def hash_join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ...
def sort_merge_join(
self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool
) -> PyMicroPartition: ...
11 changes: 7 additions & 4 deletions daft/dataframe/dataframe.py
@@ -933,9 +933,9 @@ def join(
Args:
other (DataFrame): the right DataFrame to join on.
on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on [use if the keys on the left and right side match.]. Defaults to None.
left_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on left DataFrame.. Defaults to None.
left_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on left DataFrame. Defaults to None.
right_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on right DataFrame. Defaults to None.
how (str, optional): what type of join to performing, currently only `inner` is supported. Defaults to "inner".
how (str, optional): what type of join to perform; currently "inner", "left", "right", and "outer" are supported. Defaults to "inner".
strategy (Optional[str]): The join strategy (algorithm) to use; currently "hash", "sort_merge", "broadcast", and None are supported, where None
chooses the join strategy automatically during query optimization. The default is None.

@@ -955,10 +955,13 @@ def join(
left_on = on
right_on = on
join_type = JoinType.from_join_type_str(how)
if join_type != JoinType.Inner:
raise ValueError(f"Only inner joins are currently supported, but got: {how}")
join_strategy = JoinStrategy.from_join_strategy_str(strategy) if strategy is not None else None

if join_strategy == JoinStrategy.SortMerge and join_type != JoinType.Inner:
raise ValueError("Sort merge join only supports inner joins")
if join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer:
raise ValueError("Broadcast join does not support outer joins")

left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,))
right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,))
builder = self._builder.join(
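
As a rough usage sketch of the expanded `how` argument documented in the docstring above (the data and column names are illustrative, not taken from this PR):

```python
import daft

left = daft.from_pydict({"id": [1, 2, 3], "x": ["a", "b", "c"]})
right = daft.from_pydict({"id": [2, 3, 4], "y": [20, 30, 40]})

inner = left.join(right, on="id", how="inner")   # only ids 2 and 3 survive
left_j = left.join(right, on="id", how="left")   # all left rows; unmatched rows get nulls for `y`
outer = left.join(right, on="id", how="outer")   # rows from both sides, nulls filled as needed

# Per the checks added above, sort-merge remains inner-only and broadcast rejects outer joins:
# left.join(right, on="id", how="left", strategy="sort_merge")    # raises ValueError
# left.join(right, on="id", how="outer", strategy="broadcast")    # raises ValueError
```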
23 changes: 8 additions & 15 deletions daft/logical/builder.py
@@ -194,21 +194,14 @@ def join(  # type: ignore[override]
how: JoinType = JoinType.Inner,
strategy: JoinStrategy | None = None,
) -> LogicalPlanBuilder:
if how == JoinType.Left:
raise NotImplementedError("Left join not implemented.")
elif how == JoinType.Right:
raise NotImplementedError("Right join not implemented.")
elif how == JoinType.Inner:
builder = self._builder.join(
right._builder,
[expr._expr for expr in left_on],
[expr._expr for expr in right_on],
how,
strategy,
)
return LogicalPlanBuilder(builder)
else:
raise NotImplementedError(f"{how} join not implemented.")
builder = self._builder.join(
right._builder,
[expr._expr for expr in left_on],
[expr._expr for expr in right_on],
how,
strategy,
)
return LogicalPlanBuilder(builder)

def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: ignore[override]
builder = self._builder.concat(other._builder)
4 changes: 1 addition & 3 deletions daft/table/micropartition.py
@@ -256,8 +256,6 @@ def hash_join(
right_on: ExpressionsProjection,
how: JoinType = JoinType.Inner,
) -> MicroPartition:
if how != JoinType.Inner:
raise NotImplementedError("TODO: [RUST] Implement Other Join types")
if len(left_on) != len(right_on):
raise ValueError(
f"Mismatch of number of join keys, left_on: {len(left_on)}, right_on: {len(right_on)}\nleft_on {left_on}\nright_on {right_on}"
@@ -270,7 +268,7 @@
right_exprs = [e._expr for e in right_on]

return MicroPartition._from_pymicropartition(
self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs)
self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs, how=how)
)

def sort_merge_join(
6 changes: 3 additions & 3 deletions daft/table/table.py
@@ -327,8 +327,6 @@ def hash_join(
right_on: ExpressionsProjection,
how: JoinType = JoinType.Inner,
) -> Table:
if how != JoinType.Inner:
raise NotImplementedError("TODO: [RUST] Implement Other Join types")
if len(left_on) != len(right_on):
raise ValueError(
f"Mismatch of number of join keys, left_on: {len(left_on)}, right_on: {len(right_on)}\nleft_on {left_on}\nright_on {right_on}"
@@ -340,7 +338,9 @@
left_exprs = [e._expr for e in left_on]
right_exprs = [e._expr for e in right_on]

return Table._from_pytable(self._table.hash_join(right._table, left_on=left_exprs, right_on=right_exprs))
return Table._from_pytable(
self._table.hash_join(right._table, left_on=left_exprs, right_on=right_exprs, how=how)
)

def sort_merge_join(
self,
6 changes: 4 additions & 2 deletions src/daft-plan/src/join.rs → src/daft-core/src/join.rs
@@ -3,8 +3,8 @@ use std::{
str::FromStr,
};

use crate::impl_bincode_py_state_serialization;
use common_error::{DaftError, DaftResult};
use daft_core::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::{
exceptions::PyValueError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo,
@@ -20,6 +20,7 @@ pub enum JoinType {
Inner,
Left,
Right,
Outer,
}

#[cfg(feature = "python")]
@@ -45,7 +46,7 @@ impl JoinType {
pub fn iterator() -> std::slice::Iter<'static, JoinType> {
use JoinType::*;

static JOIN_TYPES: [JoinType; 3] = [Inner, Left, Right];
static JOIN_TYPES: [JoinType; 4] = [Inner, Left, Right, Outer];
JOIN_TYPES.iter()
}
}
@@ -60,6 +61,7 @@ impl FromStr for JoinType {
"inner" => Ok(Inner),
"left" => Ok(Left),
"right" => Ok(Right),
"outer" => Ok(Outer),
_ => Err(DaftError::TypeError(format!(
"Join type {} is not supported; only the following types are supported: {:?}",
join_type,
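
Since `JoinType` is registered as a Python class (see the `register_modules` change below), the string-to-variant mapping defined by this `FromStr` impl is reachable from Python. A hedged sketch, assuming the `daft.daft` extension-module layout implied by the stub file above:

```python
from daft.daft import JoinType

# Strings accepted by from_join_type_str, matching the FromStr arms above.
for s in ["inner", "left", "right", "outer"]:
    print(s, "->", JoinType.from_join_type_str(s))

# Any other string surfaces the DaftError::TypeError from the Rust side, e.g.:
# JoinType.from_join_type_str("cross")  # error lists the supported join types
```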
4 changes: 4 additions & 0 deletions src/daft-core/src/lib.rs
@@ -7,6 +7,7 @@ pub mod count_mode;
pub mod datatypes;
#[cfg(feature = "python")]
pub mod ffi;
pub mod join;
pub mod kernels;
#[cfg(feature = "python")]
pub mod python;
@@ -18,6 +19,7 @@ use pyo3::prelude::*;

pub use count_mode::CountMode;
pub use datatypes::DataType;
pub use join::{JoinStrategy, JoinType};
pub use series::{IntoSeries, Series};

pub const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -33,6 +35,8 @@ pub const DAFT_BUILD_TYPE: &str = {
#[cfg(feature = "python")]
pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_class::<CountMode>()?;
parent.add_class::<JoinType>()?;
parent.add_class::<JoinStrategy>()?;

Ok(())
}
139 changes: 67 additions & 72 deletions src/daft-micropartition/src/ops/join.rs
@@ -1,49 +1,65 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::array::ops::DaftCompare;
use daft_core::{array::ops::DaftCompare, join::JoinType};
use daft_dsl::ExprRef;
use daft_io::IOStatsContext;
use daft_table::infer_join_schema;
use daft_table::{infer_join_schema, Table};

use crate::micropartition::MicroPartition;

use daft_stats::TruthValue;

impl MicroPartition {
pub fn hash_join(
fn join<F>(
&self,
right: &Self,
io_stats: Arc<IOStatsContext>,
left_on: &[ExprRef],
right_on: &[ExprRef],
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new("MicroPartition::hash_join");
how: JoinType,
table_join: F,
) -> DaftResult<Self>
where
F: FnOnce(&Table, &Table, &[ExprRef], &[ExprRef], JoinType) -> DaftResult<Table>,
{
let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;

if self.len() == 0 || right.len() == 0 {
return Ok(Self::empty(Some(join_schema.into())));
match (how, self.len(), right.len()) {
(JoinType::Inner, 0, _)
| (JoinType::Inner, _, 0)
| (JoinType::Left, 0, _)
| (JoinType::Right, _, 0)
| (JoinType::Outer, 0, 0) => {
return Ok(Self::empty(Some(join_schema.into())));
}
_ => {}
}

let tv = match (&self.statistics, &right.statistics) {
(_, None) => TruthValue::Maybe,
(None, _) => TruthValue::Maybe,
(Some(l), Some(r)) => {
let l_eval_stats = l.eval_expression_list(left_on, &self.schema)?;
let r_eval_stats = r.eval_expression_list(right_on, &right.schema)?;
let mut curr_tv = TruthValue::Maybe;
for (lc, rc) in l_eval_stats
.columns
.values()
.zip(r_eval_stats.columns.values())
{
if let TruthValue::False = lc.equal(rc)?.to_truth_value() {
curr_tv = TruthValue::False;
break;
if how == JoinType::Inner {
let tv = match (&self.statistics, &right.statistics) {
(_, None) => TruthValue::Maybe,
(None, _) => TruthValue::Maybe,
(Some(l), Some(r)) => {
let l_eval_stats = l.eval_expression_list(left_on, &self.schema)?;
let r_eval_stats = r.eval_expression_list(right_on, &right.schema)?;
let mut curr_tv = TruthValue::Maybe;
for (lc, rc) in l_eval_stats
.columns
.values()
.zip(r_eval_stats.columns.values())
{
if let TruthValue::False = lc.equal(rc)?.to_truth_value() {
curr_tv = TruthValue::False;
break;
}
}
curr_tv
}
curr_tv
};
if let TruthValue::False = tv {
return Ok(Self::empty(Some(join_schema.into())));
}
};
if let TruthValue::False = tv {
return Ok(Self::empty(Some(join_schema.into())));
}

// TODO(Clark): Elide concatenations where possible by doing a chunk-aware local table join.
@@ -53,7 +69,7 @@ impl MicroPartition {
match (lt.as_slice(), rt.as_slice()) {
([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))),
([lt], [rt]) => {
let joined_table = lt.hash_join(rt, left_on, right_on)?;
let joined_table = table_join(lt, rt, left_on, right_on, how)?;
Ok(MicroPartition::new_loaded(
join_schema.into(),
vec![joined_table].into(),
@@ -64,6 +80,18 @@
}
}

pub fn hash_join(
&self,
right: &Self,
left_on: &[ExprRef],
right_on: &[ExprRef],
how: JoinType,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new("MicroPartition::hash_join");

self.join(right, io_stats, left_on, right_on, how, Table::hash_join)
}

pub fn sort_merge_join(
&self,
right: &Self,
@@ -72,51 +100,18 @@
is_sorted: bool,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new("MicroPartition::sort_merge_join");
let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;

if self.len() == 0 || right.len() == 0 {
return Ok(Self::empty(Some(join_schema.into())));
}

let tv = match (&self.statistics, &right.statistics) {
(_, None) => TruthValue::Maybe,
(None, _) => TruthValue::Maybe,
(Some(l), Some(r)) => {
let l_eval_stats = l.eval_expression_list(left_on, &self.schema)?;
let r_eval_stats = r.eval_expression_list(right_on, &right.schema)?;
let mut curr_tv = TruthValue::Maybe;
for (lc, rc) in l_eval_stats
.columns
.values()
.zip(r_eval_stats.columns.values())
{
if let TruthValue::False = lc.equal(rc)?.to_truth_value() {
curr_tv = TruthValue::False;
break;
}
}
curr_tv
}
};
if let TruthValue::False = tv {
return Ok(Self::empty(Some(join_schema.into())));
}

// TODO(Clark): Elide concatenations where possible by doing a chunk-aware local table join.
let lt = self.concat_or_get(io_stats.clone())?;
let rt = right.concat_or_get(io_stats)?;
let table_join =
|lt: &Table, rt: &Table, lo: &[ExprRef], ro: &[ExprRef], _how: JoinType| {
Table::sort_merge_join(lt, rt, lo, ro, is_sorted)
};

match (lt.as_slice(), rt.as_slice()) {
([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))),
([lt], [rt]) => {
let joined_table = lt.sort_merge_join(rt, left_on, right_on, is_sorted)?;
Ok(MicroPartition::new_loaded(
join_schema.into(),
vec![joined_table].into(),
None,
))
}
_ => unreachable!(),
}
self.join(
right,
io_stats,
left_on,
right_on,
JoinType::Inner,
table_join,
)
}
}
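
Two behavioral points in the refactor above are worth spelling out: the empty-input short-circuit now depends on the join type, and the statistics-based pruning (where `TruthValue::False` means the key ranges cannot overlap) is applied only for inner joins, since left/right/outer joins must still emit rows from the non-matching side. A minimal Python sketch of the short-circuit rule, mirroring the `(how, left_len, right_len)` match arms:

```python
def join_is_trivially_empty(how: str, left_len: int, right_len: int) -> bool:
    """Mirrors the early-return match arms in MicroPartition::join (illustrative only)."""
    if how == "inner":
        # Inner joins are empty if either side is empty.
        return left_len == 0 or right_len == 0
    if how == "left":
        # Left joins are empty only if the left side is empty.
        return left_len == 0
    if how == "right":
        # Right joins are empty only if the right side is empty.
        return right_len == 0
    if how == "outer":
        # Outer joins are empty only if both sides are empty.
        return left_len == 0 and right_len == 0
    raise ValueError(f"unknown join type: {how}")
```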