Skip to content

Commit

Permalink
Avoid redundant pass-by-value in optimizer (#12262)
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Sep 2, 2024
1 parent 53de592 commit 447cb02
Show file tree
Hide file tree
Showing 19 changed files with 73 additions and 74 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1740,8 +1740,8 @@ impl OptimizerConfig for SessionState {
self.execution_props.query_execution_start_time
}

fn alias_generator(&self) -> Arc<AliasGenerator> {
self.execution_props.alias_generator.clone()
fn alias_generator(&self) -> &Arc<AliasGenerator> {
&self.execution_props.alias_generator
}

fn options(&self) -> &ConfigOptions {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
}

/// merge inputs schema into a single schema.
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
if inputs.len() == 1 {
inputs[0].schema().as_ref().clone()
} else {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec<Expr>) -> Result<Vec<Expr>> {
// If there is a REPLACE statement, replace that column with the given
// replace expression. Column name remains the same.
let replaced = if let Some(replace) = options.replace {
replace_columns(expanded, replace)?
replace_columns(expanded, &replace)?
} else {
expanded
};
Expand All @@ -95,7 +95,7 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec<Expr>) -> Result<Vec<Expr>> {
// If there is a REPLACE statement, replace that column with the given
// replace expression. Column name remains the same.
let replaced = if let Some(replace) = options.replace {
replace_columns(expanded, replace)?
replace_columns(expanded, &replace)?
} else {
expanded
};
Expand Down Expand Up @@ -139,7 +139,7 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec<Expr>) -> Result<Vec<Expr>> {
/// Multiple REPLACEs are also possible with comma separations.
fn replace_columns(
mut exprs: Vec<Expr>,
replace: PlannedReplaceSelectItem,
replace: &PlannedReplaceSelectItem,
) -> Result<Vec<Expr>> {
for expr in exprs.iter_mut() {
if let Expr::Column(Column { name, .. }) = expr {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl ApplyFunctionRewrites {
) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(plan.inputs());
let mut schema = merge_schema(&plan.inputs());

if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
Expand Down
3 changes: 1 addition & 2 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
// under the License.

use std::ops::Deref;
use std::sync::Arc;

use crate::analyzer::check_plan;
use crate::utils::collect_subquery_cols;
Expand Down Expand Up @@ -246,7 +245,7 @@ fn check_aggregation_in_scalar_subquery(
if !agg.group_expr.is_empty() {
let correlated_exprs = get_correlated_expressions(inner_plan)?;
let inner_subquery_cols =
collect_subquery_cols(&correlated_exprs, Arc::clone(agg.input.schema()))?;
collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
let mut group_columns = agg
.group_expr
.iter()
Expand Down
32 changes: 16 additions & 16 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn analyze_internal(
) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(plan.inputs());
let mut schema = merge_schema(&plan.inputs());

if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
Expand Down Expand Up @@ -544,12 +544,12 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarVa
/// Downstream code uses this signal to treat these values as *unbounded*.
fn coerce_scalar_range_aware(
target_type: &DataType,
value: ScalarValue,
value: &ScalarValue,
) -> Result<ScalarValue> {
coerce_scalar(target_type, &value).or_else(|err| {
coerce_scalar(target_type, value).or_else(|err| {
// If type coercion fails, check if the largest type in family works:
if let Some(largest_type) = get_widest_type_in_family(target_type) {
coerce_scalar(largest_type, &value).map_or_else(
coerce_scalar(largest_type, value).map_or_else(
|_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
|_| ScalarValue::try_from(target_type),
)
Expand Down Expand Up @@ -578,11 +578,11 @@ fn coerce_frame_bound(
) -> Result<WindowFrameBound> {
match bound {
WindowFrameBound::Preceding(v) => {
coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Preceding)
coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
}
WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
WindowFrameBound::Following(v) => {
coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Following)
coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
}
}
}
Expand Down Expand Up @@ -1459,26 +1459,26 @@ mod test {

fn cast_helper(
case: Case,
case_when_type: DataType,
then_else_type: DataType,
case_when_type: &DataType,
then_else_type: &DataType,
schema: &DFSchemaRef,
) -> Case {
let expr = case
.expr
.map(|e| cast_if_not_same_type(e, &case_when_type, schema));
.map(|e| cast_if_not_same_type(e, case_when_type, schema));
let when_then_expr = case
.when_then_expr
.into_iter()
.map(|(when, then)| {
(
cast_if_not_same_type(when, &case_when_type, schema),
cast_if_not_same_type(then, &then_else_type, schema),
cast_if_not_same_type(when, case_when_type, schema),
cast_if_not_same_type(then, then_else_type, schema),
)
})
.collect::<Vec<_>>();
let else_expr = case
.else_expr
.map(|e| cast_if_not_same_type(e, &then_else_type, schema));
.map(|e| cast_if_not_same_type(e, then_else_type, schema));

Case {
expr,
Expand Down Expand Up @@ -1526,8 +1526,8 @@ mod test {
let then_else_common_type = DataType::Utf8;
let expected = cast_helper(
case.clone(),
case_when_common_type,
then_else_common_type,
&case_when_common_type,
&then_else_common_type,
&schema,
);
let actual = coerce_case_expression(case, &schema)?;
Expand All @@ -1546,8 +1546,8 @@ mod test {
let then_else_common_type = DataType::Utf8;
let expected = cast_helper(
case.clone(),
case_when_common_type,
then_else_common_type,
&case_when_common_type,
&then_else_common_type,
&schema,
);
let actual = coerce_case_expression(case, &schema)?;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl CommonSubexprEliminate {
fn rewrite_exprs_list<'n>(
&self,
exprs_list: Vec<Vec<Expr>>,
arrays_list: Vec<Vec<IdArray<'n>>>,
arrays_list: &[Vec<IdArray<'n>>],
expr_stats: &ExprStats<'n>,
common_exprs: &mut CommonExprs<'n>,
alias_generator: &AliasGenerator,
Expand Down Expand Up @@ -284,10 +284,10 @@ impl CommonSubexprEliminate {
// Must clone as Identifiers use references to original expressions so we have
// to keep the original expressions intact.
exprs_list.clone(),
id_arrays_list,
&id_arrays_list,
&expr_stats,
&mut common_exprs,
&config.alias_generator(),
config.alias_generator().as_ref(),
)?;
assert!(!common_exprs.is_empty());

Expand Down
14 changes: 7 additions & 7 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
}

fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let subquery_schema = Arc::clone(plan.schema());
let subquery_schema = plan.schema();
match &plan {
LogicalPlan::Filter(plan_filter) => {
let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
Expand Down Expand Up @@ -231,7 +231,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
{
proj_exprs_evaluation_result_on_empty_batch(
&projection.expr,
Arc::clone(projection.input.schema()),
projection.input.schema(),
expr_result_map,
&mut expr_result_map_for_count_bug,
)?;
Expand Down Expand Up @@ -277,7 +277,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
{
agg_exprs_evaluation_result_on_empty_batch(
&aggregate.aggr_expr,
Arc::clone(aggregate.input.schema()),
aggregate.input.schema(),
&mut expr_result_map_for_count_bug,
)?;
if !expr_result_map_for_count_bug.is_empty() {
Expand Down Expand Up @@ -423,7 +423,7 @@ fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr

fn agg_exprs_evaluation_result_on_empty_batch(
agg_expr: &[Expr],
schema: DFSchemaRef,
schema: &DFSchemaRef,
expr_result_map_for_count_bug: &mut ExprResultMap,
) -> Result<()> {
for e in agg_expr.iter() {
Expand All @@ -446,7 +446,7 @@ fn agg_exprs_evaluation_result_on_empty_batch(

let result_expr = result_expr.unalias();
let props = ExecutionProps::new();
let info = SimplifyContext::new(&props).with_schema(Arc::clone(&schema));
let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
let simplifier = ExprSimplifier::new(info);
let result_expr = simplifier.simplify(result_expr)?;
if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) {
Expand All @@ -459,7 +459,7 @@ fn agg_exprs_evaluation_result_on_empty_batch(

fn proj_exprs_evaluation_result_on_empty_batch(
proj_expr: &[Expr],
schema: DFSchemaRef,
schema: &DFSchemaRef,
input_expr_result_map_for_count_bug: &ExprResultMap,
expr_result_map_for_count_bug: &mut ExprResultMap,
) -> Result<()> {
Expand All @@ -483,7 +483,7 @@ fn proj_exprs_evaluation_result_on_empty_batch(

if result_expr.ne(expr) {
let props = ExecutionProps::new();
let info = SimplifyContext::new(&props).with_schema(Arc::clone(&schema));
let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
let simplifier = ExprSimplifier::new(info);
let result_expr = simplifier.simplify(result_expr)?;
let expr_name = match expr {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
fn build_join(
query_info: &SubqueryInfo,
left: &LogicalPlan,
alias: Arc<AliasGenerator>,
alias: &Arc<AliasGenerator>,
) -> Result<Option<LogicalPlan>> {
let where_in_expr_opt = &query_info.where_in_expr;
let in_predicate_opt = where_in_expr_opt
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
extract_possible_join_keys(left, &mut left_join_keys);
extract_possible_join_keys(right, &mut right_join_keys);

join_keys.insert_intersection(left_join_keys, right_join_keys)
join_keys.insert_intersection(&left_join_keys, &right_join_keys)
}
_ => (),
};
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/join_key_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl JoinKeySet {
}

/// Inserts any join keys that are common to both `s1` and `s2` into self
pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) {
pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) {
// note can't use inner.intersection as we need to consider both (l, r)
// and (r, l) in equality
for (left, right) in s1.inner.iter() {
Expand Down Expand Up @@ -234,7 +234,7 @@ mod test {
let mut set = JoinKeySet::new();
// put something in there already
set.insert(&col("x"), &col("y"));
set.insert_intersection(set1, set2);
set.insert_intersection(&set1, &set2);

assert_contents(
&set,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn optimize_projections(
match plan {
LogicalPlan::Projection(proj) => {
return merge_consecutive_projections(proj)?.transform_data(|proj| {
rewrite_projection_given_requirements(proj, config, indices)
rewrite_projection_given_requirements(proj, config, &indices)
})
}
LogicalPlan::Aggregate(aggregate) => {
Expand Down Expand Up @@ -754,7 +754,7 @@ fn add_projection_on_top_if_helpful(
fn rewrite_projection_given_requirements(
proj: Projection,
config: &dyn OptimizerConfig,
indices: RequiredIndicies,
indices: &RequiredIndicies,
) -> Result<Transformed<LogicalPlan>> {
let Projection { expr, input, .. } = proj;

Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub trait OptimizerConfig {
fn query_execution_start_time(&self) -> DateTime<Utc>;

/// Return alias generator used to generate unique aliases for subqueries
fn alias_generator(&self) -> Arc<AliasGenerator>;
fn alias_generator(&self) -> &Arc<AliasGenerator>;

fn options(&self) -> &ConfigOptions;

Expand Down Expand Up @@ -204,8 +204,8 @@ impl OptimizerConfig for OptimizerContext {
self.query_execution_start_time
}

fn alias_generator(&self) -> Arc<AliasGenerator> {
Arc::clone(&self.alias_generator)
fn alias_generator(&self) -> &Arc<AliasGenerator> {
&self.alias_generator
}

fn options(&self) -> &ConfigOptions {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/optimizer/src/scalar_subquery_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl ScalarSubqueryToJoin {
fn extract_subquery_exprs(
&self,
predicate: &Expr,
alias_gen: Arc<AliasGenerator>,
alias_gen: &Arc<AliasGenerator>,
) -> Result<(Vec<(Subquery, String)>, Expr)> {
let mut extract = ExtractScalarSubQuery {
sub_query_info: vec![],
Expand Down Expand Up @@ -223,12 +223,12 @@ fn contains_scalar_subquery(expr: &Expr) -> bool {
.expect("Inner is always Ok")
}

struct ExtractScalarSubQuery {
struct ExtractScalarSubQuery<'a> {
sub_query_info: Vec<(Subquery, String)>,
alias_gen: Arc<AliasGenerator>,
alias_gen: &'a Arc<AliasGenerator>,
}

impl TreeNodeRewriter for ExtractScalarSubQuery {
impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
type Node = Expr;

fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
Expand Down
Loading

0 comments on commit 447cb02

Please sign in to comment.