From 111995e394d93e56f2104cd78c7967868d45852c Mon Sep 17 00:00:00 2001 From: AnmolS Date: Fri, 20 Sep 2024 22:45:46 -0700 Subject: [PATCH] [BUG] Fix join errors with same key name joins (resolves #2649) The issue fixed here had a workaround previously - aliasing the duplicate column name. This is not needed anymore as the aliasing is performed under the hood, taking care of uniqueness of individual column keys to avoid the duplicate issue. --- Cargo.lock | 1 + src/daft-plan/Cargo.toml | 1 + src/daft-plan/src/logical_ops/join.rs | 26 ++++++++++++++++++++++++-- tests/dataframe/test_joins.py | 11 +++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90e4745ba8..6c30de5007 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2080,6 +2080,7 @@ dependencies = [ "serde", "snafu", "test-log", + "uuid 1.10.0", ] [[package]] diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 1c5d224d89..98ddde173e 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -32,6 +32,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 2a68390066..a0e56665c5 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -13,6 +13,7 @@ use daft_dsl::{ }; use itertools::Itertools; use snafu::ResultExt; +use uuid::Uuid; use crate::{ logical_ops::Project, @@ -54,10 +55,11 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> logical_plan::Result { + let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on, right_on); let (left_on, left_fields) = - resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; + resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?; let (right_on, right_fields) = - resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?; + resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?; for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] { let on_schema = Schema::new(on_fields).context(CreationSnafu)?; @@ -167,6 +169,26 @@ impl Join { } } + fn rename_join_keys( + left_exprs: Vec>, + right_exprs: Vec>, + ) -> (Vec>, Vec>) { + left_exprs + .into_iter() + .zip(right_exprs) + .map(|(left_expr, right_expr)| { + if left_expr != right_expr { + let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = left_expr.alias(unique_id.clone()); + let renamed_right_expr = right_expr.alias(unique_id); + (renamed_left_expr, renamed_right_expr) + } else { + (left_expr, right_expr) + } + }) + .unzip() + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..177941dc49 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -53,6 +53,17 @@ def test_columns_after_join(make_df): assert set(joined_df2.schema().column_names()) == set(["A", "B"]) +def test_duplicate_join_keys_in_dataframe(make_df): + df1 = make_df({"A": [1, 2], "B": [2, 2]}) + + df2 = make_df({"A": [1, 2]}) + joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"]) + joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"]) + + assert set(joined_df1.schema().column_names()) == set(["A", "B"]) + assert set(joined_df2.schema().column_names()) == set(["A", "B"]) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) @pytest.mark.parametrize( "join_strategy",