From 7378183ae09cff8c29b9decad8836938e2d3e923 Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Fri, 16 Aug 2024 18:46:15 -0700 Subject: [PATCH] Allow merging of Project into ActorPoolProject if the Project doesn't have computation --- .../src/logical_ops/actor_pool_project.rs | 29 +++++- .../rules/push_down_projection.rs | 93 +++++++++++++++++-- 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/src/daft-plan/src/logical_ops/actor_pool_project.rs b/src/daft-plan/src/logical_ops/actor_pool_project.rs index 63457cbff5..e991431b94 100644 --- a/src/daft-plan/src/logical_ops/actor_pool_project.rs +++ b/src/daft-plan/src/logical_ops/actor_pool_project.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use common_error::DaftError; use common_resource_request::ResourceRequest; use common_treenode::TreeNode; use daft_core::schema::{Schema, SchemaRef}; @@ -14,7 +15,7 @@ use itertools::Itertools; use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Result}, + logical_plan::{CreationSnafu, Error, Result}, LogicalPlan, }; @@ -30,7 +31,33 @@ impl ActorPoolProject { pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { let (projection, fields) = resolve_exprs(projection, input.schema().as_ref()).context(CreationSnafu)?; + + let num_stateful_udf_exprs: usize = projection + .iter() + .map(|expr| { + let mut num_stateful_udfs = 0; + expr.apply(|e| { + if matches!( + e.as_ref(), + Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(_)), + .. 
+ } + ) { + num_stateful_udfs += 1; + } + Ok(common_treenode::TreeNodeRecursion::Continue) + }) + .unwrap(); + num_stateful_udfs + }) + .sum(); + if num_stateful_udf_exprs != 1 { + return Err(Error::CreationError { source: DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 stateful UDF expression but found: {num_stateful_udf_exprs}")) }); + } + let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + Ok(ActorPoolProject { input, projection, diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index fa13ef796a..046e5028d7 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -4,8 +4,14 @@ use common_error::DaftResult; use common_treenode::TreeNode; use daft_core::{schema::Schema, JoinType}; -use daft_dsl::{col, optimization::replace_columns_with_expressions, Expr, ExprRef}; +use daft_dsl::{ + col, + functions::{python::PythonUDF, FunctionExpr}, + optimization::{get_required_columns, replace_columns_with_expressions, requires_computation}, + Expr, ExprRef, +}; use indexmap::IndexSet; +use itertools::Itertools; use crate::{ logical_ops::{ActorPoolProject, Aggregate, Join, Pivot, Project, Source}, @@ -229,6 +235,71 @@ impl PushDownProjection { } } LogicalPlan::ActorPoolProject(upstream_actor_pool_projection) => { + // Attempt to merge the current Projection into the upstream ActorPoolProject + // if there aren't any actual computations being performed in the Projection, and + // if each upstream column is used only once (no common subtrees) + if projection + .projection + .iter() + .all(|e| !requires_computation(e)) + { + let required_column_names = projection + .projection + .iter() + .flat_map(get_required_columns) + .collect_vec(); + let distinct_required_column_names = + 
required_column_names.iter().collect::>().len(); + if required_column_names.len() == distinct_required_column_names { + let actor_pool_projection_map = upstream_actor_pool_projection + .projection + .iter() + .map(|e| (e.name().to_string(), e.clone())) + .collect::>(); + let new_actor_pool_projections = projection + .projection + .iter() + .map(|p| { + replace_columns_with_expressions( + p.clone(), + &actor_pool_projection_map, + ) + }) + .collect_vec(); + + // Construct either a new ActorPoolProject or Project, depending on whether the pruned projection still has StatefulUDFs + let new_plan = if new_actor_pool_projections.iter().any(|e| { + e.exists(|e| { + matches!( + e.as_ref(), + Expr::Function { + func: FunctionExpr::Python(PythonUDF::Stateful(_)), + .. + } + ) + }) + }) { + LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + upstream_actor_pool_projection.input.clone(), + new_actor_pool_projections, + )?) + .arced() + } else { + LogicalPlan::Project(Project::try_new( + upstream_actor_pool_projection.input.clone(), + new_actor_pool_projections, + )?) + .arced() + }; + + // Retry optimization now that the node is different. + let new_plan = self + .try_optimize(new_plan.clone())? + .or(Transformed::Yes(new_plan)); + return Ok(new_plan); + } + } + // Prune columns from the child ActorPoolProjection that are not used in this projection. 
let required_columns = &plan.required_columns()[0]; if required_columns.len() < upstream_schema.names().len() { @@ -841,7 +912,7 @@ mod tests { Field::new("b", DataType::Boolean), Field::new("c", DataType::Int64), ]); - let scan_node = dummy_scan_node(scan_op).build(); + let scan_node = dummy_scan_node(scan_op.clone()); let mock_stateful_udf = Expr::Function { func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { name: Arc::new("my-udf".to_string()), @@ -857,7 +928,7 @@ mod tests { // Select the `udf_results` column, so the ActorPoolProject should apply column pruning to the other columns let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - scan_node.clone(), + scan_node.build(), vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")], )?) .arced(); @@ -868,7 +939,11 @@ mod tests { .arced(); let expected_actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - scan_node.clone(), + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_columns(Some(Arc::new(vec!["c".to_string()]))), + ) + .build(), vec![mock_stateful_udf.alias("udf_results")], )?) .arced(); @@ -892,7 +967,7 @@ mod tests { Field::new("b", DataType::Boolean), Field::new("c", DataType::Int64), ]); - let scan_node = dummy_scan_node(scan_op).build(); + let scan_node = dummy_scan_node(scan_op.clone()).build(); let mock_stateful_udf = Expr::Function { func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF { name: Arc::new("my-udf".to_string()), @@ -929,8 +1004,12 @@ mod tests { .arced(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - scan_node.clone(), - vec![col("a"), mock_stateful_udf.alias("udf_results_0")], + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_columns(Some(Arc::new(vec!["a".to_string()]))), + ) + .build(), + vec![mock_stateful_udf.alias("udf_results_0"), col("a")], )?) 
.arced(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(