[BUG] Fix join errors when join keys share the same column name (resolves #2649)
The issue fixed here previously required a workaround: manually
aliasing the duplicate column name. That workaround is no longer
needed, because the aliasing is now performed under the hood, keeping
each join-key column name unique and avoiding the duplicate-key error.
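For illustration, a minimal sketch of the user-facing difference. This assumes daft.from_pydict for building DataFrames (the test suite uses a make_df fixture instead), and the commented-out workaround uses a hypothetical "A_alias" column via with_column; it is a sketch of the old pattern, not code taken from the issue.

import daft

df1 = daft.from_pydict({"A": [1, 2], "B": [2, 2]})
df2 = daft.from_pydict({"A": [1, 2]})

# Old workaround (sketch, not verbatim from the issue): alias the duplicate
# key on one side before joining.
#   df2_aliased = df2.with_column("A_alias", df2["A"])
#   joined = df1.join(df2_aliased, left_on=["A", "B"], right_on=["A", "A_alias"])

# After this fix, the duplicate key can be passed directly; the aliasing is
# handled under the hood.
joined = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])
assert set(joined.schema().column_names()) == set(["A", "B"])

Under the hood, the fix aliases each duplicate join-key expression with a generated UUID so the resolved key fields stay unique.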
AnmolS authored and AnmolS committed Sep 21, 2024
1 parent dba931f commit 3fd90e2
Showing 4 changed files with 44 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
@@ -32,6 +32,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
uuid = { version = "1", features = ["v4"] }

[dev-dependencies]
daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
35 changes: 33 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
@@ -13,6 +13,7 @@ use daft_dsl::{
};
use itertools::Itertools;
use snafu::ResultExt;
use uuid::Uuid;

use crate::{
logical_ops::Project,
@@ -54,10 +55,11 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
+        let (unique_left_on, unique_right_on) = Self::process_expressions(left_on, right_on);
         let (left_on, left_fields) =
-            resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
+            resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
         let (right_on, right_fields) =
-            resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;
+            resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
@@ -167,6 +169,35 @@
}
}

fn deduplicate_exprs(exprs: Vec<Arc<Expr>>) -> Vec<Arc<Expr>> {
    let mut counts: HashMap<Arc<Expr>, usize> = HashMap::new();

    exprs
        .into_iter()
        .map(|expr| {
            let count = counts.entry(expr.clone()).or_insert(0);
            *count += 1;

            if *count == 1 {
                expr // First occurrence: keep the original expression
            } else {
                let unique_id = Uuid::new_v4();
                expr.alias(format!("{}", unique_id)) // Alias duplicates with a fresh UUID to keep key names unique
            }
        })
        .collect()
}

fn process_expressions(
left_on: Vec<Arc<Expr>>,
right_on: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
let unique_left_on = Self::deduplicate_exprs(left_on);
let unique_right_on = Self::deduplicate_exprs(right_on);

(unique_left_on, unique_right_on)
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!("Join: Type = {}", self.join_type));
9 changes: 9 additions & 0 deletions tests/dataframe/test_joins.py
@@ -53,6 +53,15 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_duplicate_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
joined_df = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])

assert set(joined_df.schema().column_names()) == set(["A", "B"])


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
