Skip to content

Commit

Permalink
[BUG] Fix join errors with same key name joins (resolves #2649)
Browse files Browse the repository at this point in the history
Previously, this issue had to be worked around by manually aliasing the
duplicate column name. That is no longer needed: the aliasing is now
performed under the hood, giving each join key column a unique name so
that duplicate key names no longer cause errors.
  • Loading branch information
AnmolS authored and AnmolS committed Sep 27, 2024
1 parent dba931f commit 111995e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
uuid = {version = "1", features = ["v4"]}

[dev-dependencies]
daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
Expand Down
26 changes: 24 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use daft_dsl::{
};
use itertools::Itertools;
use snafu::ResultExt;
use uuid::Uuid;

use crate::{
logical_ops::Project,
Expand Down Expand Up @@ -54,10 +55,11 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on, right_on);
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;
resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
Expand Down Expand Up @@ -167,6 +169,26 @@ impl Join {
}
}

/// Gives each positionally paired set of join keys that differ a shared, freshly generated alias.
///
/// The i-th left key is paired with the i-th right key. When the two expressions of a pair
/// are not equal, both are aliased to the same random (v4) UUID string, so the two sides of
/// that pair end up with one common name. Pairs that compare equal are passed through
/// untouched.
///
/// If the two lists have different lengths they are returned unchanged. `zip` stops at the
/// shorter input, which would silently drop the unpaired trailing keys and make the outputs
/// look equally long, hiding the mismatch from the caller.
fn rename_join_keys(
    left_exprs: Vec<Arc<Expr>>,
    right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
    // Never truncate: hand mismatched key lists back as-is so no join key is lost here.
    if left_exprs.len() != right_exprs.len() {
        return (left_exprs, right_exprs);
    }

    left_exprs
        .into_iter()
        .zip(right_exprs)
        .map(|(left_expr, right_expr)| {
            if left_expr == right_expr {
                // Identical expressions on both sides need no renaming.
                return (left_expr, right_expr);
            }
            // One id per pair, so both sides of the pair share the same alias.
            let unique_id = Uuid::new_v4().to_string();
            let renamed_left_expr = left_expr.alias(unique_id.clone());
            let renamed_right_expr = right_expr.alias(unique_id);
            (renamed_left_expr, renamed_right_expr)
        })
        .unzip()
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!("Join: Type = {}", self.join_type));
Expand Down
11 changes: 11 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_duplicate_join_keys_in_dataframe(make_df):
    """Join on a right-hand key used twice and check the resulting column names."""
    left = make_df({"A": [1, 2], "B": [2, 2]})
    right = make_df({"A": [1, 2]})

    # The right-hand column "A" is paired with two different left-hand columns,
    # once in each order.
    for left_keys in (["A", "B"], ["B", "A"]):
        joined = left.join(right, left_on=left_keys, right_on=["A", "A"])
        assert set(joined.schema().column_names()) == set(["A", "B"])


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
Expand Down

0 comments on commit 111995e

Please sign in to comment.