feat: Add method to add analyzer rules to SessionContext (apache#10849)
* feat: Add method to add analyzer rules to SessionContext

Signed-off-by: Kevin Su <[email protected]>

* Add a test

Signed-off-by: Kevin Su <[email protected]>

* Add analyze_plan

Signed-off-by: Kevin Su <[email protected]>

* update test

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
2 people authored and xinlifoobar committed Jun 22, 2024
1 parent 1a6c4a2 commit 5d317c7
Showing 3 changed files with 107 additions and 6 deletions.
10 changes: 10 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
@@ -75,6 +75,7 @@ use url::Url;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_optimizer::AnalyzerRule;

mod avro;
mod csv;
@@ -331,6 +332,15 @@ impl SessionContext {
self
}

/// Adds an analyzer rule to the `SessionState` in the current `SessionContext`.
pub fn add_analyzer_rule(
self,
analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
self.state.write().add_analyzer_rule(analyzer_rule);
self
}

/// Registers an [`ObjectStore`] to be used with a specific URL prefix.
///
/// See [`RuntimeEnv::register_object_store`] for more details.
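For context, a minimal usage sketch of the new builder-style method on `SessionContext` (a sketch only; `MyAnalyzerRule` stands in for any user-defined `AnalyzerRule`, such as the one added to the tests below):

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::prelude::*;

async fn analyzed_context() -> Result<()> {
    // `add_analyzer_rule` consumes the context and returns it, so it can be chained
    // directly onto `SessionContext::new()`.
    let ctx = SessionContext::new().add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
    // Queries planned through this context now pass through the custom rule
    // during analysis, before the optimizer runs.
    let df = ctx.sql("SELECT 42, arrow_typeof(42)").await?;
    df.show().await?;
    Ok(())
}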
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/session_state.rs
@@ -384,9 +384,9 @@ impl SessionState {
/// Add `analyzer_rule` to the end of the list of
/// [`AnalyzerRule`]s used to rewrite queries.
pub fn add_analyzer_rule(
mut self,
&mut self,
analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
) -> &Self {
self.analyzer.rules.push(analyzer_rule);
self
}
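With the signature above now taking `&mut self`, a rule is pushed onto an existing `SessionState` in place rather than through a consuming builder call. A minimal sketch, mirroring `make_topk_context` in the test file below (same imports and types as that file):

let config = SessionConfig::new();
let runtime = Arc::new(RuntimeEnv::default());
// The state must be bound as mutable; `add_analyzer_rule` mutates it in place
// and returns a shared reference rather than consuming it.
let mut state = SessionState::new_with_config_rt(config, runtime);
state.add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
let ctx = SessionContext::new_with_state(state);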
99 changes: 95 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
@@ -92,8 +92,12 @@ use datafusion::{
};

use async_trait::async_trait;
use datafusion_common::tree_node::Transformed;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::Projection;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use futures::{Stream, StreamExt};

/// Execute the specified sql and return the resulting record batches
@@ -132,11 +136,13 @@ async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result<SessionC
Ok(ctx)
}

const QUERY1: &str = "SELECT * FROM sales limit 3";

const QUERY: &str =
"SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3";

const QUERY1: &str = "SELECT * FROM sales limit 3";

const QUERY2: &str = "SELECT 42, arrow_typeof(42)";

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Result<()> {
@@ -164,6 +170,34 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re
Ok(())
}

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query_with_analyzer_rule(
mut ctx: SessionContext,
description: &str,
) -> Result<()> {
let expected = vec![
"+------------+--------------------------+",
"| UInt64(42) | arrow_typeof(UInt64(42)) |",
"+------------+--------------------------+",
"| 42 | UInt64 |",
"+------------+--------------------------+",
];

let s = exec_sql(&mut ctx, QUERY2).await?;
let actual = s.lines().collect::<Vec<_>>();

assert_eq!(
expected,
actual,
"output mismatch for {}. Expectedn\n{}Actual:\n{}",
description,
expected.join("\n"),
s
);
Ok(())
}

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query_with_auto_schemas(
@@ -208,6 +242,13 @@ async fn normal_query() -> Result<()> {
run_and_compare_query(ctx, "Default context").await
}

#[tokio::test]
// Run the query using default planners, optimizer and custom analyzer rule
async fn normal_query_with_analyzer() -> Result<()> {
let ctx = SessionContext::new().add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await
}

#[tokio::test]
// Run the query using topk optimization
async fn topk_query() -> Result<()> {
@@ -248,9 +289,10 @@ async fn topk_plan() -> Result<()> {
fn make_topk_context() -> SessionContext {
let config = SessionConfig::new().with_target_partitions(48);
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionState::new_with_config_rt(config, runtime)
let mut state = SessionState::new_with_config_rt(config, runtime)
.with_query_planner(Arc::new(TopKQueryPlanner {}))
.add_optimizer_rule(Arc::new(TopKOptimizerRule {}));
state.add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
SessionContext::new_with_state(state)
}

@@ -633,3 +675,52 @@ impl RecordBatchStream for TopKReader {
self.input.schema()
}
}

struct MyAnalyzerRule {}

impl AnalyzerRule for MyAnalyzerRule {
fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
Self::analyze_plan(plan)
}

fn name(&self) -> &str {
"my_analyzer_rule"
}
}

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(|plan| {
Ok(match plan {
LogicalPlan::Projection(projection) => {
let expr = Self::analyze_expr(projection.expr.clone())?;
Transformed::yes(LogicalPlan::Projection(Projection::try_new(
expr,
projection.input,
)?))
}
_ => Transformed::no(plan),
})
})
.data()
}

fn analyze_expr(expr: Vec<Expr>) -> Result<Vec<Expr>> {
expr.into_iter()
.map(|e| {
e.transform(|e| {
Ok(match e {
Expr::Literal(ScalarValue::Int64(i)) => {
// transform to UInt64
Transformed::yes(Expr::Literal(ScalarValue::UInt64(
i.map(|i| i as u64),
)))
}
_ => Transformed::no(e),
})
})
.data()
})
.collect()
}
}
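
In effect, `MyAnalyzerRule` rewrites every `Int64` literal inside a projection into a `UInt64` literal, which is why the `QUERY2` test above expects `arrow_typeof(42)` to report `UInt64`. A minimal sketch of the expression-level rewrite, using the same types as the test file (illustrative values only):

// Before the rule runs, the planner produces a plain Int64 literal for `42`.
let exprs = vec![Expr::Literal(ScalarValue::Int64(Some(42)))];
// After the rule runs, the same value is re-typed as UInt64.
let rewritten = MyAnalyzerRule::analyze_expr(exprs)?;
assert!(matches!(
    rewritten[0],
    Expr::Literal(ScalarValue::UInt64(Some(42)))
));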
