@@ -903,7 +903,7 @@ impl LogicalPlan {
903903 let ( left, right) = self . only_two_inputs ( inputs) ?;
904904 let schema = build_join_schema ( left. schema ( ) , right. schema ( ) , join_type) ?;
905905
906- let equi_expr_count = on. len ( ) ;
906+ let equi_expr_count = on. len ( ) * 2 ;
907907 assert ! ( expr. len( ) >= equi_expr_count) ;
908908
909909 // Assume that the last expr, if any,
@@ -917,17 +917,16 @@ impl LogicalPlan {
917917 // The first part of expr is the equi-exprs, flattened so that each join key
918918 // contributes two consecutive entries: `left-expr` followed by `right-expr`.
919919 assert_eq ! ( expr. len( ) , equi_expr_count) ;
920- let new_on = expr. into_iter ( ) . map ( |equi_expr| {
920+ let mut new_on = Vec :: with_capacity ( on. len ( ) ) ;
921+ let mut iter = expr. into_iter ( ) ;
922+ while let Some ( left) = iter. next ( ) {
923+ let Some ( right) = iter. next ( ) else {
924+ internal_err ! ( "Expected a pair of expressions to construct the join on expression" ) ?
925+ } ;
926+
921927 // SimplifyExpression rule may add alias to the equi_expr.
922- let unalias_expr = equi_expr. clone ( ) . unalias ( ) ;
923- if let Expr :: BinaryExpr ( BinaryExpr { left, op : Operator :: Eq , right } ) = unalias_expr {
924- Ok ( ( * left, * right) )
925- } else {
926- internal_err ! (
927- "The front part expressions should be an binary equality expression, actual:{equi_expr}"
928- )
929- }
930- } ) . collect :: < Result < Vec < ( Expr , Expr ) > > > ( ) ?;
928+ new_on. push ( ( left. unalias ( ) , right. unalias ( ) ) ) ;
929+ }
931930
932931 Ok ( LogicalPlan :: Join ( Join {
933932 left : Arc :: new ( left) ,
@@ -3780,7 +3779,8 @@ mod tests {
37803779 use crate :: builder:: LogicalTableSource ;
37813780 use crate :: logical_plan:: table_scan;
37823781 use crate :: {
3783- col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet ,
3782+ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
3783+ GroupingSet ,
37843784 } ;
37853785
37863786 use datafusion_common:: tree_node:: {
@@ -4632,4 +4632,110 @@ digraph {
46324632 let parameter_type = params. clone ( ) . get ( placeholder_value) . unwrap ( ) . clone ( ) ;
46334633 assert_eq ! ( parameter_type, None ) ;
46344634 }
4635+
// Checks how `with_new_exprs` rebuilds a `LogicalPlan::Join`: the `on` pairs and
// the optional `filter` of the returned join are compared against expected values
// for four different combinations of equi-join keys and filter.
4636+ #[ test]
4637+ fn test_join_with_new_exprs ( ) -> Result < ( ) > {
// Helper: builds an inner join (`JoinConstraint::On`) between two table scans,
// `t1` and `t2`, that share the same two-column Int32 schema (`a`, `b`), using
// the caller-supplied `on` pairs and `filter`.
4638+ fn create_test_join (
4639+ on : Vec < ( Expr , Expr ) > ,
4640+ filter : Option < Expr > ,
4641+ ) -> Result < LogicalPlan > {
4642+ let schema = Schema :: new ( vec ! [
4643+ Field :: new( "a" , DataType :: Int32 , false ) ,
4644+ Field :: new( "b" , DataType :: Int32 , false ) ,
4645+ ] ) ;
4646+
// The same arrow schema is qualified once per side so columns resolve as `t1.*` / `t2.*`.
4647+ let left_schema = DFSchema :: try_from_qualified_schema ( "t1" , & schema) ?;
4648+ let right_schema = DFSchema :: try_from_qualified_schema ( "t2" , & schema) ?;
4649+
4650+ Ok ( LogicalPlan :: Join ( Join {
4651+ left : Arc :: new (
4652+ table_scan ( Some ( "t1" ) , left_schema. as_arrow ( ) , None ) ?. build ( ) ?,
4653+ ) ,
4654+ right : Arc :: new (
4655+ table_scan ( Some ( "t2" ) , right_schema. as_arrow ( ) , None ) ?. build ( ) ?,
4656+ ) ,
4657+ on,
4658+ filter,
4659+ join_type : JoinType :: Inner ,
4660+ join_constraint : JoinConstraint :: On ,
4661+ schema : Arc :: new ( left_schema. join ( & right_schema) ?) ,
4662+ null_equals_null : false ,
4663+ } ) )
4664+ }
4665+
// Case 1: one equi-join pair and no filter. Feeding the plan's own
// `expressions()` and inputs back in must reproduce the same `on` and `filter`.
4666+ {
4667+ let join = create_test_join ( vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] , None ) ?;
4668+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4669+ join. expressions ( ) ,
4670+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4671+ ) ?
4672+ else {
4673+ unreachable ! ( )
4674+ } ;
4675+ assert_eq ! ( join. on, vec![ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ) ;
4676+ assert_eq ! ( join. filter, None ) ;
4677+ }
4678+
// Case 2: no equi-join pairs, only a filter. The round trip must keep `on`
// empty and preserve the filter.
4679+ {
4680+ let join = create_test_join ( vec ! [ ] , Some ( col ( "t1.a" ) . gt ( col ( "t2.a" ) ) ) ) ?;
4681+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4682+ join. expressions ( ) ,
4683+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4684+ ) ?
4685+ else {
4686+ unreachable ! ( )
4687+ } ;
4688+ assert_eq ! ( join. on, vec![ ] ) ;
4689+ assert_eq ! ( join. filter, Some ( col( "t1.a" ) . gt( col( "t2.a" ) ) ) ) ;
4690+ }
4691+
// Case 3: one equi-join pair together with a filter. Both must survive the round trip.
4692+ {
4693+ let join = create_test_join (
4694+ vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ,
4695+ Some ( col ( "t1.b" ) . gt ( col ( "t2.b" ) ) ) ,
4696+ ) ?;
4697+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4698+ join. expressions ( ) ,
4699+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4700+ ) ?
4701+ else {
4702+ unreachable ! ( )
4703+ } ;
4704+ assert_eq ! ( join. on, vec![ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ) ;
4705+ assert_eq ! ( join. filter, Some ( col( "t1.b" ) . gt( col( "t2.b" ) ) ) ) ;
4706+ }
4707+
// Case 4: hand-written replacement expressions for a join with two `on` pairs
// and no original filter. The first four entries are expected to be grouped into
// two (left, right) pairs and the trailing fifth entry to become the filter.
4708+ {
4709+ let join = create_test_join (
4710+ vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) , ( col( "t1.b" ) , ( col( "t2.b" ) ) ) ] ,
4711+ None ,
4712+ ) ?;
4713+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4714+ vec ! [
4715+ binary_expr( col( "t1.a" ) , Operator :: Plus , lit( 1 ) ) ,
4716+ binary_expr( col( "t2.a" ) , Operator :: Plus , lit( 2 ) ) ,
4717+ col( "t1.b" ) ,
4718+ col( "t2.b" ) ,
4719+ lit( true ) ,
4720+ ] ,
4721+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4722+ ) ?
4723+ else {
4724+ unreachable ! ( )
4725+ } ;
4726+ assert_eq ! (
4727+ join. on,
4728+ vec![
4729+ (
4730+ binary_expr( col( "t1.a" ) , Operator :: Plus , lit( 1 ) ) ,
4731+ binary_expr( col( "t2.a" ) , Operator :: Plus , lit( 2 ) )
4732+ ) ,
4733+ ( col( "t1.b" ) , ( col( "t2.b" ) ) )
4734+ ]
4735+ ) ;
4736+ assert_eq ! ( join. filter, Some ( lit( true ) ) ) ;
4737+ }
4738+
4739+ Ok ( ( ) )
4740+ }
46354741}
0 commit comments