Skip to content

Commit 8356c94

Browse files
authored
Handle columns in with_new_exprs with a Join (#15055)
* handle columns in with_new_exprs with Join
* test doesn't return result
* take join from result
* clippy
* make test fallible
* accept any pair of expressions for new_on in with_new_exprs for Join
* use with_capacity
1 parent c247b02 commit 8356c94

File tree

1 file changed

+118
-12
lines changed
  • datafusion/expr/src/logical_plan

1 file changed

+118
-12
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ impl LogicalPlan {
903903
let (left, right) = self.only_two_inputs(inputs)?;
904904
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
905905

906-
let equi_expr_count = on.len();
906+
let equi_expr_count = on.len() * 2;
907907
assert!(expr.len() >= equi_expr_count);
908908

909909
// Assume that the last expr, if any,
@@ -917,17 +917,16 @@ impl LogicalPlan {
917917
// The first part of expr is equi-exprs,
918918
// and each equi-expr contributes two consecutive entries: `left-expr`, then `right-expr`.
919919
assert_eq!(expr.len(), equi_expr_count);
920-
let new_on = expr.into_iter().map(|equi_expr| {
920+
let mut new_on = Vec::with_capacity(on.len());
921+
let mut iter = expr.into_iter();
922+
while let Some(left) = iter.next() {
923+
let Some(right) = iter.next() else {
924+
internal_err!("Expected a pair of expressions to construct the join on expression")?
925+
};
926+
921927
// SimplifyExpression rule may add alias to the equi_expr.
922-
let unalias_expr = equi_expr.clone().unalias();
923-
if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
924-
Ok((*left, *right))
925-
} else {
926-
internal_err!(
927-
"The front part expressions should be an binary equality expression, actual:{equi_expr}"
928-
)
929-
}
930-
}).collect::<Result<Vec<(Expr, Expr)>>>()?;
928+
new_on.push((left.unalias(), right.unalias()));
929+
}
931930

932931
Ok(LogicalPlan::Join(Join {
933932
left: Arc::new(left),
@@ -3780,7 +3779,8 @@ mod tests {
37803779
use crate::builder::LogicalTableSource;
37813780
use crate::logical_plan::table_scan;
37823781
use crate::{
3783-
col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet,
3782+
binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
3783+
GroupingSet,
37843784
};
37853785

37863786
use datafusion_common::tree_node::{
@@ -4632,4 +4632,110 @@ digraph {
46324632
let parameter_type = params.clone().get(placeholder_value).unwrap().clone();
46334633
assert_eq!(parameter_type, None);
46344634
}
4635+
4636+
#[test]
4637+
fn test_join_with_new_exprs() -> Result<()> {
4638+
fn create_test_join(
4639+
on: Vec<(Expr, Expr)>,
4640+
filter: Option<Expr>,
4641+
) -> Result<LogicalPlan> {
4642+
let schema = Schema::new(vec![
4643+
Field::new("a", DataType::Int32, false),
4644+
Field::new("b", DataType::Int32, false),
4645+
]);
4646+
4647+
let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
4648+
let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?;
4649+
4650+
Ok(LogicalPlan::Join(Join {
4651+
left: Arc::new(
4652+
table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?,
4653+
),
4654+
right: Arc::new(
4655+
table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?,
4656+
),
4657+
on,
4658+
filter,
4659+
join_type: JoinType::Inner,
4660+
join_constraint: JoinConstraint::On,
4661+
schema: Arc::new(left_schema.join(&right_schema)?),
4662+
null_equals_null: false,
4663+
}))
4664+
}
4665+
4666+
{
4667+
let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?;
4668+
let LogicalPlan::Join(join) = join.with_new_exprs(
4669+
join.expressions(),
4670+
join.inputs().into_iter().cloned().collect(),
4671+
)?
4672+
else {
4673+
unreachable!()
4674+
};
4675+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4676+
assert_eq!(join.filter, None);
4677+
}
4678+
4679+
{
4680+
let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?;
4681+
let LogicalPlan::Join(join) = join.with_new_exprs(
4682+
join.expressions(),
4683+
join.inputs().into_iter().cloned().collect(),
4684+
)?
4685+
else {
4686+
unreachable!()
4687+
};
4688+
assert_eq!(join.on, vec![]);
4689+
assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
4690+
}
4691+
4692+
{
4693+
let join = create_test_join(
4694+
vec![(col("t1.a"), (col("t2.a")))],
4695+
Some(col("t1.b").gt(col("t2.b"))),
4696+
)?;
4697+
let LogicalPlan::Join(join) = join.with_new_exprs(
4698+
join.expressions(),
4699+
join.inputs().into_iter().cloned().collect(),
4700+
)?
4701+
else {
4702+
unreachable!()
4703+
};
4704+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4705+
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
4706+
}
4707+
4708+
{
4709+
let join = create_test_join(
4710+
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
4711+
None,
4712+
)?;
4713+
let LogicalPlan::Join(join) = join.with_new_exprs(
4714+
vec![
4715+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4716+
binary_expr(col("t2.a"), Operator::Plus, lit(2)),
4717+
col("t1.b"),
4718+
col("t2.b"),
4719+
lit(true),
4720+
],
4721+
join.inputs().into_iter().cloned().collect(),
4722+
)?
4723+
else {
4724+
unreachable!()
4725+
};
4726+
assert_eq!(
4727+
join.on,
4728+
vec![
4729+
(
4730+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4731+
binary_expr(col("t2.a"), Operator::Plus, lit(2))
4732+
),
4733+
(col("t1.b"), (col("t2.b")))
4734+
]
4735+
);
4736+
assert_eq!(join.filter, Some(lit(true)));
4737+
}
4738+
4739+
Ok(())
4740+
}
46354741
}

0 commit comments

Comments
 (0)