Skip to content

Commit

Permalink
Renaming join keys except when they are same key name column expressions
Browse files Browse the repository at this point in the history
Rename join keys except when they share the same key name
and are Column expression types; include original expression name
in the renamed expression.
  • Loading branch information
AnmolS authored and AnmolS committed Oct 3, 2024
1 parent 111995e commit 1f661b4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 21 deletions.
90 changes: 70 additions & 20 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftError;
use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand Down Expand Up @@ -55,15 +55,31 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on, right_on);
let (left_on, left_fields) =
resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) {
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, _) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

let (unique_left_on, unique_right_on) =
Self::rename_join_keys(left_on.clone(), right_on.clone());

let left_fields: Vec<Field> = unique_left_on
.iter()
.map(|e| e.to_field(&left.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

let right_fields: Vec<Field> = unique_right_on
.iter()
.map(|e| e.to_field(&right.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

for (on_exprs, on_fields) in [
(&unique_left_on, &left_fields),
(&unique_right_on, &right_fields),
] {
for (field, expr) in on_fields.iter().zip(on_exprs.iter()) {
// Null type check for both fields and expressions
if matches!(field.dtype, DataType::Null) {
return Err(DaftError::ValueError(format!(
"Can't join on null type expressions: {expr}"
Expand Down Expand Up @@ -169,23 +185,57 @@ impl Join {
}
}

/// Renames join keys for the given left and right expressions. This is required to
/// prevent errors when the join keys on the left and right expressions have the same key
/// name.
///
/// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and
/// checks for pairs of column expressions that differ. If both expressions in a pair
/// are column expressions and they are not identical, it generates a unique identifier
/// and renames both expressions by appending this identifier to their original names.
///
/// The function returns two vectors of expressions, where the renamed expressions are
/// substituted for the original expressions in the cases where renaming occurred.
///
/// # Parameters
/// - `left_exprs`: A vector of expressions from the left side of a join.
/// - `right_exprs`: A vector of expressions from the right side of a join.
///
/// # Returns
/// A tuple containing two vectors of expressions, one for the left side and one for the
/// right side, where expressions that needed to be renamed have been modified.
///
/// # Example
/// ```
/// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions);
/// ```
///
/// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649).

fn rename_join_keys(
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
left_exprs
.into_iter()
.zip(right_exprs)
.map(|(left_expr, right_expr)| {
if left_expr != right_expr {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr = left_expr.alias(unique_id.clone());
let renamed_right_expr = right_expr.alias(unique_id);
(renamed_left_expr, renamed_right_expr)
} else {
(left_expr, right_expr)
}
})
.map(
|(left_expr, right_expr)| match (&*left_expr, &*right_expr) {
(Expr::Column(left_name), Expr::Column(right_name))
if left_name == right_name =>
{
(left_expr, right_expr)
}
_ => {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr =
left_expr.alias(format!("{}_{}", left_expr.name(), unique_id));
let renamed_right_expr =
right_expr.alias(format!("{}_{}", right_expr.name(), unique_id));
(renamed_left_expr, renamed_right_expr)
}
},
)
.unzip()
}

Expand Down
2 changes: 1 addition & 1 deletion tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_duplicate_join_keys_in_dataframe(make_df):
def test_rename_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
Expand Down

0 comments on commit 1f661b4

Please sign in to comment.