Skip to content

Commit cb0beec

Browse files
delamarch3avantgardnerio
authored andcommitted
Handle columns in with_new_exprs with a Join (apache#15055)
* handle columns in with_new_exprs with Join * test doesn't return result * take join from result * clippy * make test fallible * accept any pair of expression for new_on in with_new_exprs for Join * use with_capacity (cherry picked from commit 8356c94)
1 parent 724e220 commit cb0beec

File tree

1 file changed

+148
-21
lines changed
  • datafusion/expr/src/logical_plan

1 file changed

+148
-21
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 148 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ impl LogicalPlan {
890890
let (left, right) = self.only_two_inputs(inputs)?;
891891
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
892892

893-
let equi_expr_count = on.len();
893+
let equi_expr_count = on.len() * 2;
894894
assert!(expr.len() >= equi_expr_count);
895895

896896
// Assume that the last expr, if any,
@@ -904,17 +904,16 @@ impl LogicalPlan {
904904
// The first part of expr is equi-exprs,
905905
// and the struct of each equi-expr is like `left-expr = right-expr`.
906906
assert_eq!(expr.len(), equi_expr_count);
907-
let new_on = expr.into_iter().map(|equi_expr| {
907+
let mut new_on = Vec::with_capacity(on.len());
908+
let mut iter = expr.into_iter();
909+
while let Some(left) = iter.next() {
910+
let Some(right) = iter.next() else {
911+
internal_err!("Expected a pair of expressions to construct the join on expression")?
912+
};
913+
908914
// SimplifyExpression rule may add alias to the equi_expr.
909-
let unalias_expr = equi_expr.clone().unalias();
910-
if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
911-
Ok((*left, *right))
912-
} else {
913-
internal_err!(
914-
"The front part expressions should be an binary equality expression, actual:{equi_expr}"
915-
)
916-
}
917-
}).collect::<Result<Vec<(Expr, Expr)>>>()?;
915+
new_on.push((left.unalias(), right.unalias()));
916+
}
918917

919918
Ok(LogicalPlan::Join(Join {
920919
left: Arc::new(left),
@@ -3423,15 +3422,15 @@ pub enum Partitioning {
34233422
/// input output_name
34243423
/// ┌─────────┐ ┌─────────┐
34253424
/// │{{1,2}} │ │ 1 │
3426-
/// ├─────────┼─────►├─────────┤
3427-
/// │{{3}} │ │ 2 │
3428-
/// ├─────────┤ ├─────────┤
3429-
/// │{{4},{5}}│ │ 3 │
3430-
/// └─────────┘ ├─────────┤
3431-
/// │ 4 │
3432-
/// ├─────────┤
3433-
/// │ 5 │
3434-
/// └─────────┘
3425+
/// ├─────────┼─────►├─────────┤
3426+
/// │{{3}} │ │ 2 │
3427+
/// ├─────────┤ ├─────────┤
3428+
/// │{{4},{5}}│ │ 3 │
3429+
/// └─────────┘ ├─────────┤
3430+
/// │ 4 │
3431+
/// ├─────────┤
3432+
/// │ 5 │
3433+
/// └─────────┘
34353434
/// ```
34363435
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
34373436
pub struct ColumnUnnestList {
@@ -3516,7 +3515,8 @@ mod tests {
35163515
use crate::builder::LogicalTableSource;
35173516
use crate::logical_plan::table_scan;
35183517
use crate::{
3519-
col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet,
3518+
binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
3519+
GroupingSet,
35203520
};
35213521

35223522
use datafusion_common::tree_node::{
@@ -4347,4 +4347,131 @@ digraph {
43474347
plan.rewrite_with_subqueries(&mut rewriter).unwrap();
43484348
assert!(!rewriter.filter_found);
43494349
}
4350+
4351+
#[test]
4352+
fn test_with_unresolved_placeholders() {
4353+
let field_name = "id";
4354+
let placeholder_value = "$1";
4355+
let schema = Schema::new(vec![Field::new(field_name, DataType::Int32, false)]);
4356+
4357+
let plan = table_scan(TableReference::none(), &schema, None)
4358+
.unwrap()
4359+
.filter(col(field_name).eq(placeholder(placeholder_value)))
4360+
.unwrap()
4361+
.build()
4362+
.unwrap();
4363+
4364+
// Check that the placeholder parameters have not received a DataType.
4365+
let params = plan.get_parameter_types().unwrap();
4366+
assert_eq!(params.len(), 1);
4367+
4368+
let parameter_type = params.clone().get(placeholder_value).unwrap().clone();
4369+
assert_eq!(parameter_type, None);
4370+
}
4371+
4372+
#[test]
4373+
fn test_join_with_new_exprs() -> Result<()> {
4374+
fn create_test_join(
4375+
on: Vec<(Expr, Expr)>,
4376+
filter: Option<Expr>,
4377+
) -> Result<LogicalPlan> {
4378+
let schema = Schema::new(vec![
4379+
Field::new("a", DataType::Int32, false),
4380+
Field::new("b", DataType::Int32, false),
4381+
]);
4382+
4383+
let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
4384+
let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?;
4385+
4386+
Ok(LogicalPlan::Join(Join {
4387+
left: Arc::new(
4388+
table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?,
4389+
),
4390+
right: Arc::new(
4391+
table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?,
4392+
),
4393+
on,
4394+
filter,
4395+
join_type: JoinType::Inner,
4396+
join_constraint: JoinConstraint::On,
4397+
schema: Arc::new(left_schema.join(&right_schema)?),
4398+
null_equals_null: false,
4399+
}))
4400+
}
4401+
4402+
{
4403+
let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?;
4404+
let LogicalPlan::Join(join) = join.with_new_exprs(
4405+
join.expressions(),
4406+
join.inputs().into_iter().cloned().collect(),
4407+
)?
4408+
else {
4409+
unreachable!()
4410+
};
4411+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4412+
assert_eq!(join.filter, None);
4413+
}
4414+
4415+
{
4416+
let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?;
4417+
let LogicalPlan::Join(join) = join.with_new_exprs(
4418+
join.expressions(),
4419+
join.inputs().into_iter().cloned().collect(),
4420+
)?
4421+
else {
4422+
unreachable!()
4423+
};
4424+
assert_eq!(join.on, vec![]);
4425+
assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
4426+
}
4427+
4428+
{
4429+
let join = create_test_join(
4430+
vec![(col("t1.a"), (col("t2.a")))],
4431+
Some(col("t1.b").gt(col("t2.b"))),
4432+
)?;
4433+
let LogicalPlan::Join(join) = join.with_new_exprs(
4434+
join.expressions(),
4435+
join.inputs().into_iter().cloned().collect(),
4436+
)?
4437+
else {
4438+
unreachable!()
4439+
};
4440+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4441+
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
4442+
}
4443+
4444+
{
4445+
let join = create_test_join(
4446+
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
4447+
None,
4448+
)?;
4449+
let LogicalPlan::Join(join) = join.with_new_exprs(
4450+
vec![
4451+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4452+
binary_expr(col("t2.a"), Operator::Plus, lit(2)),
4453+
col("t1.b"),
4454+
col("t2.b"),
4455+
lit(true),
4456+
],
4457+
join.inputs().into_iter().cloned().collect(),
4458+
)?
4459+
else {
4460+
unreachable!()
4461+
};
4462+
assert_eq!(
4463+
join.on,
4464+
vec![
4465+
(
4466+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4467+
binary_expr(col("t2.a"), Operator::Plus, lit(2))
4468+
),
4469+
(col("t1.b"), (col("t2.b")))
4470+
]
4471+
);
4472+
assert_eq!(join.filter, Some(lit(true)));
4473+
}
4474+
4475+
Ok(())
4476+
}
43504477
}

0 commit comments

Comments
 (0)