Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 146 additions & 10 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,28 +906,43 @@ impl LogicalPlan {
let equi_expr_count = on.len();
assert!(expr.len() >= equi_expr_count);

let col_pair_count =
expr.iter().filter(|e| matches!(e, Expr::Column(_))).count() / 2;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filter can also be a column expression if it's a boolean column, and an `on` expression may not be a column expr at all, e.g. `select * from t1 join t2 on t1.a+2 = t2.a+1 where t2.b`.

I think we don't need to match expr types; we just extract them according to the format returned by apply_expressions(), where the first on.len() * 2 elements are the on-expression pairs, and the last one is the filter expression.


// Assume that the last expr, if any,
// is the filter_expr (non equality predicate from ON clause)
let filter_expr = if expr.len() > equi_expr_count {
let filter_expr = if expr.len() - col_pair_count > equi_expr_count {
expr.pop()
} else {
None
};

// The first part of expr is equi-exprs,
// and the struct of each equi-expr is like `left-expr = right-expr`.
assert_eq!(expr.len(), equi_expr_count);
let new_on = expr.into_iter().map(|equi_expr| {
assert_eq!(expr.len() - col_pair_count, equi_expr_count);
let mut new_on = Vec::new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: using `with_capacity` would be preferable.

 let mut new_on = Vec::with_capacity(on.len());

let mut iter = expr.into_iter();
while let Some(equi_expr) = iter.next() {
// SimplifyExpression rule may add alias to the equi_expr.
let unalias_expr = equi_expr.clone().unalias();
if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
Ok((*left, *right))
} else {
internal_err!(
"The front part expressions should be an binary equality expression, actual:{equi_expr}"
)
match unalias_expr {
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) => new_on.push((*left, *right)),
left @ Expr::Column(_) => {
let Some(right) = iter.next() else {
internal_err!("Expected a pair of columns to construct the join on expression")?
};

new_on.push((left, right));
}
_ => internal_err!(
"The front part expressions should be a binary equality expression or a column expression, actual:{equi_expr}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"The front part expressions should be a binary equality expression or a column expression, actual:{equi_expr}"
"The front part expressions should be a binary equality expression or a column expression, actual: {equi_expr}"

)?
}
}).collect::<Result<Vec<(Expr, Expr)>>>()?;
}

Ok(LogicalPlan::Join(Join {
left: Arc::new(left),
Expand Down Expand Up @@ -4630,4 +4645,125 @@ digraph {
let parameter_type = params.clone().get(placeholder_value).unwrap().clone();
assert_eq!(parameter_type, None);
}

#[test]
fn test_join_with_new_exprs() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn test_join_with_new_exprs() {
fn test_join_with_new_exprs() -> Result<()> {

Make the function fallible so that the many `unwrap`s can be replaced with `?`.

/// Builds an inner `ON` join between scans of tables `t1` and `t2` for the tests.
///
/// Both tables are declared with the same two non-nullable `Int32` columns,
/// `a` and `b`. The `on` pairs and the optional `filter` are stored in the
/// resulting `Join` exactly as passed in.
///
/// Panics if building either qualified schema, either scan, or the joined
/// schema fails.
fn create_test_join(on: Vec<(Expr, Expr)>, filter: Option<Expr>) -> LogicalPlan {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);

    // Same columns on both sides, differing only in the table qualifier.
    let t1_schema = DFSchema::try_from_qualified_schema("t1", &schema).unwrap();
    let t2_schema = DFSchema::try_from_qualified_schema("t2", &schema).unwrap();

    // Inputs are built left first, then right, then the combined schema.
    let t1_scan = table_scan(Some("t1"), t1_schema.as_arrow(), None)
        .unwrap()
        .build()
        .unwrap();
    let t2_scan = table_scan(Some("t2"), t2_schema.as_arrow(), None)
        .unwrap()
        .build()
        .unwrap();
    let joined_schema = t1_schema.join(&t2_schema).unwrap();

    let join = Join {
        left: Arc::new(t1_scan),
        right: Arc::new(t2_scan),
        on,
        filter,
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        schema: Arc::new(joined_schema),
        null_equals_null: false,
    };
    LogicalPlan::Join(join)
}

// Case 1: one equi-join pair and no filter. Feeding the plan's own
// `expressions()` back through `with_new_exprs` must leave `on` unchanged
// and keep `filter` as `None`.
{
let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None);
let LogicalPlan::Join(join) = join
.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)
.unwrap()
else {
unreachable!()
};
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
assert_eq!(join.filter, None);
}

// Case 2: no equi-join pairs, only a filter. The round trip must keep `on`
// empty and return the same filter expression.
{
let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))));
let LogicalPlan::Join(join) = join
.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)
.unwrap()
else {
unreachable!()
};
assert_eq!(join.on, vec![]);
assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
}

// Case 3: one equi-join pair together with a filter. Both must survive the
// round trip unchanged.
{
let join = create_test_join(
vec![(col("t1.a"), (col("t2.a")))],
Some(col("t1.b").gt(col("t2.b"))),
);
let LogicalPlan::Join(join) = join
.with_new_exprs(
join.expressions(),
join.inputs().into_iter().cloned().collect(),
)
.unwrap()
else {
unreachable!()
};
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
}

// Case 4: hand-written expression list mixing both accepted shapes. The
// first pair is given as a binary equality, the second as two consecutive
// bare columns, and the trailing `lit(true)` is expected to end up as the
// filter.
{
let join = create_test_join(
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
None,
);
let LogicalPlan::Join(join) = join
.with_new_exprs(
vec![
col("t1.a").eq(col("t2.a")),
col("t1.b"),
col("t2.b"),
lit(true),
],
join.inputs().into_iter().cloned().collect(),
)
.unwrap()
else {
unreachable!()
};
assert_eq!(
join.on,
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))]
);
assert_eq!(join.filter, Some(lit(true)));
}

// Case 5: malformed input. The second pair supplies only its left column
// (`t1.b`) with no partner, so `with_new_exprs` must return an error.
{
let join = create_test_join(
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
None,
);
let res = join.with_new_exprs(
vec![col("t1.a").eq(col("t2.a")), col("t1.b")],
join.inputs().into_iter().cloned().collect(),
);
assert!(res.is_err());
}
}
}