Skip to content

Commit

Permalink
[BUG] Fix join errors with same key name joins (resolves #2649)
Browse files Browse the repository at this point in the history
The issue fixed here had a workaround previously - aliasing the
duplicate column name. This is not needed anymore as the aliasing
is performed under the hood, taking care of uniqueness of individual
column keys to avoid the duplicate issue.
  • Loading branch information
AnmolS authored and AnmolS committed Sep 26, 2024
1 parent dba931f commit 53aa100
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
uuid = {version = "1", features = ["v4"]}

[dev-dependencies]
daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
Expand Down
37 changes: 35 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use daft_dsl::{
};
use itertools::Itertools;
use snafu::ResultExt;
use uuid::Uuid;

use crate::{
logical_ops::Project,
Expand Down Expand Up @@ -54,10 +55,11 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on, right_on);
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;
resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
Expand Down Expand Up @@ -167,6 +169,37 @@ impl Join {
}
}

fn rename_join_keys(
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
left_exprs
.into_iter()
.enumerate()
.map(|(i, left_expr)| {
let right_expr = &right_exprs[i];

// Rename left_expr if it's different from right_expr
let renamed_left_expr = if left_expr != *right_expr {
let unique_id = Uuid::new_v4();
left_expr.alias(format!("{}", unique_id))
} else {
left_expr // Keep it unchanged if common
};

// Rename right_expr if it's different from left_expr
let renamed_right_expr = if *right_expr != renamed_left_expr {
let unique_id = Uuid::new_v4();
right_expr.clone().alias(format!("{}", unique_id))
} else {
right_expr.clone() // Keep it unchanged if common
};

(renamed_left_expr, renamed_right_expr)
})
.unzip()
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!("Join: Type = {}", self.join_type));
Expand Down
11 changes: 11 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_duplicate_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])
joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"])

assert set(joined_df1.schema().column_names()) == set(["A", "B"])
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
Expand Down

0 comments on commit 53aa100

Please sign in to comment.