From a26a94e77db04222209f830dbb165e7b88d34f2e Mon Sep 17 00:00:00 2001 From: Gabor Gevay Date: Fri, 20 Dec 2024 16:52:14 +0100 Subject: [PATCH] Eliminate some MRE::arity() calls --- src/compute-types/src/plan/lowering.rs | 5 +++-- src/expr/src/linear.rs | 2 ++ src/expr/src/relation.rs | 9 ++++----- src/sql/src/plan/lowering.rs | 5 ++++- src/transform/src/fusion/project.rs | 3 ++- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/compute-types/src/plan/lowering.rs b/src/compute-types/src/plan/lowering.rs index c8b9fe5356932..0c5cfb3f6b61c 100644 --- a/src/compute-types/src/plan/lowering.rs +++ b/src/compute-types/src/plan/lowering.rs @@ -548,8 +548,6 @@ impl Context { equivalences, implementation, } => { - let input_mapper = JoinInputMapper::new(inputs); - // Plan each of the join inputs independently. // The `plans` get surfaced upwards, and the `input_keys` should // be used as part of join planning / to validate the existing @@ -564,6 +562,9 @@ impl Context { input_keys.push(keys); } + let input_mapper = + JoinInputMapper::new_from_input_arities(input_arities.iter().copied()); + // Extract temporal predicates as joins cannot currently absorb them. let (plan, missing) = match implementation { IndexedFilter(_coll_id, _idx_id, key, _val) => { diff --git a/src/expr/src/linear.rs b/src/expr/src/linear.rs index 51195f264694f..c3ad2be55be33 100644 --- a/src/expr/src/linear.rs +++ b/src/expr/src/linear.rs @@ -329,6 +329,8 @@ impl MapFilterProject { let (mfp, expr) = Self::extract_from_expression(input); (mfp.project(outputs.iter().cloned()), expr) } + // TODO: The recursion is quadratic in the number of Map/Filter/Project operators due to + // this call to `arity()`. x => (Self::new(x.arity()), x), } } diff --git a/src/expr/src/relation.rs b/src/expr/src/relation.rs index 6a4ac8be2e08c..d1d5981e5b13c 100644 --- a/src/expr/src/relation.rs +++ b/src/expr/src/relation.rs @@ -1798,22 +1798,21 @@ impl MirRelationExpr { .unzip(); assert_eq!(keys_and_values.arity() - self.arity(), data.len()); self.let_in(id_gen, |_id_gen, get_keys| { + let get_keys_arity = get_keys.arity(); Ok(MirRelationExpr::join( vec![ // all the missing keys (with count 1) keys_and_values - .distinct_by((0..get_keys.arity()).collect()) + .distinct_by((0..get_keys_arity).collect()) .negate() .union(get_keys.clone().distinct()), // join with keys to get the correct counts get_keys.clone(), ], - (0..get_keys.arity()) - .map(|i| vec![(0, i), (1, i)]) - .collect(), + (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(), ) // get rid of the extra copies of columns from keys - .project((0..get_keys.arity()).collect()) + .project((0..get_keys_arity).collect()) // This join is logically equivalent to // `.map()`, but using a join allows for // potential predicate pushdown and elision in the diff --git a/src/sql/src/plan/lowering.rs b/src/sql/src/plan/lowering.rs index e6a214d265e60..b7e16f481acbd 100644 --- a/src/sql/src/plan/lowering.rs +++ b/src/sql/src/plan/lowering.rs @@ -1574,12 +1574,14 @@ impl HirScalarExpr { let inner_arity = get_inner.arity(); let mut total_arity = inner_arity; let mut join_inputs = vec![get_inner]; + let mut join_input_arities = vec![inner_arity]; for (expr, subquery) in subqueries.into_iter() { // Avoid lowering duplicated subqueries if !subquery_map.contains_key(&expr) { let subquery_arity = subquery.arity(); assert_eq!(subquery_arity, inner_arity + 1); join_inputs.push(subquery); + join_input_arities.push(subquery_arity); total_arity += subquery_arity; // Column with the value of the subquery @@ -1589,7 +1591,8 @@ impl HirScalarExpr { // Each subquery projects all the columns of the outer context (distinct_inner) // plus 1 column, containing the result of the subquery. Those columns must be // joined with the outer/main relation (get_inner). - let input_mapper = mz_expr::JoinInputMapper::new(&join_inputs); + let input_mapper = + mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities); let equivalences = (0..inner_arity) .map(|col| { join_inputs diff --git a/src/transform/src/fusion/project.rs b/src/transform/src/fusion/project.rs index adde523c895ef..772ad23de9cd8 100644 --- a/src/transform/src/fusion/project.rs +++ b/src/transform/src/fusion/project.rs @@ -52,7 +52,8 @@ impl Project { *outputs = outputs.iter().map(|i| outputs2[*i]).collect(); **input = inner.take_dangerous(); } - if outputs.iter().enumerate().all(|(a, b)| a == *b) && outputs.len() == input.arity() { + let input_arity = input.arity(); + if outputs.iter().enumerate().all(|(a, b)| a == *b) && outputs.len() == input_arity { *relation = input.take_dangerous(); } }