From 5d317c7774f12431a942873186c0dbad10659e14 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 22 Jun 2024 20:02:22 +0800 Subject: [PATCH] feat: Add method to add analyzer rules to SessionContext (#10849) * feat: Add method to add analyzer rules to SessionContext Signed-off-by: Kevin Su * Add a test Signed-off-by: Kevin Su * Add analyze_plan Signed-off-by: Kevin Su * update test Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su Co-authored-by: Andrew Lamb --- datafusion/core/src/execution/context/mod.rs | 10 ++ .../core/src/execution/session_state.rs | 4 +- .../tests/user_defined/user_defined_plan.rs | 99 ++++++++++++++++++- 3 files changed, 107 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 6fa83d3d931e8..c44e9742607e5 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -75,6 +75,7 @@ use url::Url; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; +use datafusion_optimizer::AnalyzerRule; mod avro; mod csv; @@ -331,6 +332,15 @@ impl SessionContext { self } + /// Adds an analyzer rule to the `SessionState` in the current `SessionContext`. + pub fn add_analyzer_rule( + self, + analyzer_rule: Arc, + ) -> Self { + self.state.write().add_analyzer_rule(analyzer_rule); + self + } + /// Registers an [`ObjectStore`] to be used with a specific URL prefix. /// /// See [`RuntimeEnv::register_object_store`] for more details. diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 1df77a1f9e0be..e9441a89cd5f0 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -384,9 +384,9 @@ impl SessionState { /// Add `analyzer_rule` to the end of the list of /// [`AnalyzerRule`]s used to rewrite queries. pub fn add_analyzer_rule( - mut self, + &mut self, analyzer_rule: Arc, - ) -> Self { + ) -> &Self { self.analyzer.rules.push(analyzer_rule); self } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index ebf907c5e2c08..c5654ded888ad 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,8 +92,12 @@ use datafusion::{ }; use async_trait::async_trait; -use datafusion_common::tree_node::Transformed; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::Projection; use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::AnalyzerRule; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -132,11 +136,13 @@ async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result Result<()> { @@ -164,6 +170,34 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re Ok(()) } +// Run the query using the specified execution context and compare it +// to the known result +async fn run_and_compare_query_with_analyzer_rule( + mut ctx: SessionContext, + description: &str, +) -> Result<()> { + let expected = vec![ + "+------------+--------------------------+", + "| UInt64(42) | arrow_typeof(UInt64(42)) |", + "+------------+--------------------------+", + "| 42 | UInt64 |", + "+------------+--------------------------+", + ]; + + let s = exec_sql(&mut ctx, QUERY2).await?; + let actual = s.lines().collect::>(); + + assert_eq!( + expected, + actual, + "output mismatch for {}. Expectedn\n{}Actual:\n{}", + description, + expected.join("\n"), + s + ); + Ok(()) +} + // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( @@ -208,6 +242,13 @@ async fn normal_query() -> Result<()> { run_and_compare_query(ctx, "Default context").await } +#[tokio::test] +// Run the query using default planners, optimizer and custom analyzer rule +async fn normal_query_with_analyzer() -> Result<()> { + let ctx = SessionContext::new().add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await +} + #[tokio::test] // Run the query using topk optimization async fn topk_query() -> Result<()> { @@ -248,9 +289,10 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime) + let mut state = SessionState::new_with_config_rt(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); + state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); SessionContext::new_with_state(state) } @@ -633,3 +675,52 @@ impl RecordBatchStream for TopKReader { self.input.schema() } } + +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(|plan| { + Ok(match plan { + LogicalPlan::Projection(projection) => { + let expr = Self::analyze_expr(projection.expr.clone())?; + Transformed::yes(LogicalPlan::Projection(Projection::try_new( + expr, + projection.input, + )?)) + } + _ => Transformed::no(plan), + }) + }) + .data() + } + + fn analyze_expr(expr: Vec) -> Result> { + expr.into_iter() + .map(|e| { + e.transform(|e| { + Ok(match e { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(e), + }) + }) + .data() + }) + .collect() + } +}