diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 658849acc..979540dd3 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -158,9 +158,11 @@ impl PySessionContext { }; let manifest = to_manifest(mdl_base64)?; let properties_ref = Arc::new(properties_map); - let Ok(analyzed_mdl) = - AnalyzedWrenMDL::analyze(manifest, Arc::clone(&properties_ref)) - else { + let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze( + manifest, + Arc::clone(&properties_ref), + mdl::context::Mode::Unparse, + ) else { return Err(CoreError::new("Failed to analyze manifest").into()); }; @@ -172,7 +174,7 @@ impl PySessionContext { &ctx, Arc::clone(&analyzed_mdl), Arc::new(HashMap::new()), - false, + mdl::context::Mode::Unparse, )) .map_err(CoreError::from)?; diff --git a/wren-core-py/src/errors.rs b/wren-core-py/src/errors.rs index bd65d0360..3e89aeb02 100644 --- a/wren-core-py/src/errors.rs +++ b/wren-core-py/src/errors.rs @@ -4,6 +4,8 @@ use pyo3::PyErr; use std::num::ParseIntError; use std::string::FromUtf8Error; use thiserror::Error; +use wren_core::DataFusionError; +use wren_core::WrenError; #[derive(Error, Debug, PartialEq)] #[error("{message}")] @@ -49,8 +51,15 @@ impl From for CoreError { } } -impl From for CoreError { - fn from(err: wren_core::DataFusionError) -> Self { +impl From for CoreError { + fn from(err: DataFusionError) -> Self { + if let DataFusionError::Context(_, ee) = &err { + if let DataFusionError::External(we) = ee.as_ref() { + if let Some(we) = we.downcast_ref::() { + return CoreError::new(we.to_string().as_str()); + } + } + } CoreError::new(err.to_string().as_str()) } } diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 0bb8ebec9..a531d2dbd 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -364,3 +364,13 @@ def test_clac(): rewritten_sql == "SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM main.customer AS __source) AS customer) AS customer" ) + + session_context = SessionContext(manifest_str, None, properties_hashable) + sql = "SELECT c_name FROM my_catalog.my_schema.customer" + try: + session_context.transform_sql(sql) + except Exception as e: + assert ( + str(e) + == "Permission Denied: No permission to access \"customer\".\"c_name\"" + ) diff --git a/wren-core/benchmarks/src/tpch/run.rs b/wren-core/benchmarks/src/tpch/run.rs index 7a2669b07..66e1e02b4 100644 --- a/wren-core/benchmarks/src/tpch/run.rs +++ b/wren-core/benchmarks/src/tpch/run.rs @@ -8,6 +8,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; use structopt::StructOpt; +use wren_core::mdl::context::Mode; use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; #[derive(Debug, StructOpt, Clone)] @@ -53,6 +54,7 @@ impl RunOpt { let mdl = Arc::new(AnalyzedWrenMDL::analyze( tpch_manifest(), Arc::new(HashMap::default()), + Mode::Unparse, )?); let mut millis = vec![]; // run benchmark diff --git a/wren-core/benchmarks/src/wren/run.rs b/wren-core/benchmarks/src/wren/run.rs index 9b2bf5064..a6d9260fb 100644 --- a/wren-core/benchmarks/src/wren/run.rs +++ b/wren-core/benchmarks/src/wren/run.rs @@ -9,6 +9,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; use structopt::StructOpt; +use wren_core::mdl::context::Mode; use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; #[derive(Debug, StructOpt, Clone)] @@ -63,6 +64,7 @@ impl RunOpt { let mdl = Arc::new(AnalyzedWrenMDL::analyze( get_manifest(query_id)?, Arc::new(HashMap::default()), + Mode::Unparse, )?); let start = Instant::now(); let sql = &get_query_sql(query_id)?; diff --git a/wren-core/core/src/lib.rs b/wren-core/core/src/lib.rs index 2e2985775..4005857aa 100644 --- a/wren-core/core/src/lib.rs +++ b/wren-core/core/src/lib.rs @@ -6,4 +6,5 @@ pub use datafusion::error::DataFusionError; pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; pub use datafusion::prelude::*; pub use datafusion::sql::sqlparser::*; +pub use logical_plan::error::WrenError; pub use mdl::AnalyzedWrenMDL; diff --git a/wren-core/core/src/logical_plan/analyze/plan.rs b/wren-core/core/src/logical_plan/analyze/plan.rs index 525d7a525..fe2d8be69 100644 --- a/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/wren-core/core/src/logical_plan/analyze/plan.rs @@ -19,8 +19,10 @@ use datafusion::logical_expr::{ use log::debug; use petgraph::Graph; +use crate::logical_plan::analyze::access_control::validate_clac_rule; use crate::logical_plan::analyze::RelationChain; use crate::logical_plan::analyze::RelationChain::Start; +use crate::logical_plan::error::WrenError; use crate::logical_plan::utils::{from_qualified_name, try_map_data_type}; use crate::mdl; use crate::mdl::context::SessionPropertiesRef; @@ -146,6 +148,19 @@ impl ModelPlanNodeBuilder { .any(|expr| is_required_column(expr, column.name())) }); for column in required_columns { + // Actually, it's only be checked in PermissionAnalyze mode. + // In Unparse or LocalRuntime mode, an invalid column won't be registered in the table provider. + // A column accessing will be failed by the column not found error. + if !validate_clac_rule(&column, &self.properties)? { + return Err(DataFusionError::External(Box::new( + WrenError::PermissionDenied(format!( + r#"No permission to access "{}"."{}""#, + model.name(), + column.name + )), + ))); + } + if column.is_calculated { let expr = if column.expression.is_some() { let column_rf = self diff --git a/wren-core/core/src/logical_plan/error.rs b/wren-core/core/src/logical_plan/error.rs new file mode 100644 index 000000000..17979b297 --- /dev/null +++ b/wren-core/core/src/logical_plan/error.rs @@ -0,0 +1,16 @@ +use std::{error::Error, fmt::Display}; + +#[derive(Debug, Clone)] +pub enum WrenError { + PermissionDenied(String), +} + +impl Error for WrenError {} + +impl Display for WrenError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WrenError::PermissionDenied(msg) => write!(f, "Permission Denied: {msg}"), + } + } +} diff --git a/wren-core/core/src/logical_plan/mod.rs b/wren-core/core/src/logical_plan/mod.rs index 21cd4f4c6..2fb3db423 100644 --- a/wren-core/core/src/logical_plan/mod.rs +++ b/wren-core/core/src/logical_plan/mod.rs @@ -1,3 +1,4 @@ pub mod analyze; +pub mod error; pub mod optimize; pub mod utils; diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 380c3bb4a..44ebf1f2b 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -50,7 +50,7 @@ pub async fn create_ctx_with_mdl( ctx: &SessionContext, analyzed_mdl: Arc, properties: SessionPropertiesRef, - is_local_runtime: bool, + mode: Mode, ) -> Result { let session_timezone = properties .get("x-wren-timezone") @@ -88,29 +88,77 @@ pub async fn create_ctx_with_mdl( .collect::>(), ); - let new_state = if is_local_runtime { - new_state.with_analyzer_rules(analyze_rule_for_local_runtime( - Arc::clone(&analyzed_mdl), - reset_default_catalog_schema.clone(), - Arc::clone(&properties), - )) - // The plan will be executed locally, so apply the default optimizer rules + let new_state = new_state.with_analyzer_rules(mode.get_analyze_rules( + Arc::clone(&analyzed_mdl), + Arc::clone(&reset_default_catalog_schema), + Arc::clone(&properties), + )); + let new_state = if let Some(optimize_rules) = mode.get_optimize_rules() { + new_state.with_optimizer_rules(optimize_rules) } else { new_state - .with_analyzer_rules(analyze_rule_for_unparsing( - Arc::clone(&analyzed_mdl), - reset_default_catalog_schema.clone(), - Arc::clone(&properties), - )) - .with_optimizer_rules(optimize_rule_for_unparsing()) }; let new_state = new_state.with_config(config).build(); let ctx = SessionContext::new_with_state(new_state); - register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl(), properties).await?; + register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl(), properties, mode).await?; Ok(ctx) } +/// Execution mode for Wren engine. +#[derive(Debug)] +pub enum Mode { + /// Local runtime mode, used for executing queries by DataFusion directly. + LocalRuntime, + /// Unparse mode, used for generating SQL statements. + /// This mode is used to generate SQL statements that can be executed in other SQL engines. + Unparse, + /// Permission analyze mode, used for analyzing if the error is caused by permission denied. + /// It's only be used when an error is raised during Unparse mode. + PermissionAnalyze, +} + +impl Mode { + pub fn get_analyze_rules( + &self, + analyzed_mdl: Arc, + session_state_ref: SessionStateRef, + properties: SessionPropertiesRef, + ) -> Vec> { + match self { + Mode::LocalRuntime => analyze_rule_for_local_runtime( + Arc::clone(&analyzed_mdl), + Arc::clone(&session_state_ref), + Arc::clone(&properties), + ), + Mode::Unparse => analyze_rule_for_unparsing( + Arc::clone(&analyzed_mdl), + Arc::clone(&session_state_ref), + Arc::clone(&properties), + ), + Mode::PermissionAnalyze => analyze_rule_for_permission( + Arc::clone(&analyzed_mdl), + Arc::clone(&session_state_ref), + Arc::clone(&properties), + ), + } + } + + pub fn get_optimize_rules( + &self, + ) -> Option>> { + match self { + Mode::LocalRuntime => None, + Mode::Unparse => Some(optimize_rule_for_unparsing()), + Mode::PermissionAnalyze => Some(vec![]), + } + } + + pub fn is_permission_analyze(&self) -> bool { + matches!(self, Mode::PermissionAnalyze) + } +} + // Analyzer rules for local runtime fn analyze_rule_for_local_runtime( analyzed_mdl: Arc, @@ -227,10 +275,32 @@ fn optimize_rule_for_unparsing() -> Vec> { ] } +fn analyze_rule_for_permission( + analyzed_mdl: Arc, + session_state_ref: SessionStateRef, + properties: SessionPropertiesRef, +) -> Vec> { + vec![ + // To align the lastest change in datafusion, apply this this rule first. + Arc::new(ExpandWildcardRule::new()), + // expand the view should be the first rule + Arc::new(ExpandWrenViewRule::new( + Arc::clone(&analyzed_mdl), + Arc::clone(&session_state_ref), + )), + Arc::new(ModelAnalyzeRule::new( + Arc::clone(&analyzed_mdl), + Arc::clone(&session_state_ref), + Arc::clone(&properties), + )), + ] +} + pub async fn register_table_with_mdl( ctx: &SessionContext, wren_mdl: Arc, properties: SessionPropertiesRef, + mode: Mode, ) -> Result<()> { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -239,7 +309,7 @@ pub async fn register_table_with_mdl( ctx.register_catalog(&wren_mdl.manifest.catalog, Arc::new(catalog)); for model in wren_mdl.manifest.models.iter() { - let table = WrenDataSource::new(Arc::clone(model), &properties)?; + let table = WrenDataSource::new(Arc::clone(model), &properties, &mode)?; ctx.register_table( TableReference::full(wren_mdl.catalog(), wren_mdl.schema(), model.name()), Arc::new(table), @@ -262,12 +332,17 @@ pub struct WrenDataSource { } impl WrenDataSource { - pub fn new(model: Arc, properties: &SessionPropertiesRef) -> Result { + pub fn new( + model: Arc, + properties: &SessionPropertiesRef, + mode: &Mode, + ) -> Result { let available_columns = model .get_physical_columns() .iter() .map(|column| { - if validate_clac_rule(column, properties)? { + if mode.is_permission_analyze() || validate_clac_rule(column, properties)? + { Ok(Some(Arc::clone(column))) } else { Ok(None) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 9e07d971c..734581928 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1,7 +1,8 @@ use crate::logical_plan::analyze::access_control::validate_clac_rule; +use crate::logical_plan::error::WrenError; use crate::logical_plan::utils::{from_qualified_name_str, try_map_data_type}; use crate::mdl::builder::ManifestBuilder; -use crate::mdl::context::{create_ctx_with_mdl, WrenDataSource}; +use crate::mdl::context::{create_ctx_with_mdl, Mode, WrenDataSource}; use crate::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, @@ -70,9 +71,13 @@ impl Default for AnalyzedWrenMDL { } impl AnalyzedWrenMDL { - pub fn analyze(manifest: Manifest, properties: SessionPropertiesRef) -> Result { + pub fn analyze( + manifest: Manifest, + properties: SessionPropertiesRef, + mode: Mode, + ) -> Result { let wren_mdl = Arc::new(WrenMDL::infer_and_register_remote_table( - manifest, properties, + manifest, properties, mode, )?); let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?); Ok(AnalyzedWrenMDL { wren_mdl, lineage }) @@ -184,6 +189,7 @@ impl WrenMDL { pub fn infer_and_register_remote_table( manifest: Manifest, properties: SessionPropertiesRef, + mode: Mode, ) -> Result { let mut mdl = WrenMDL::new(manifest); let sources: Vec<_> = mdl @@ -195,7 +201,9 @@ impl WrenMDL { .columns .iter() .map(|column| { - if validate_clac_rule(column, &properties)? { + if mode.is_permission_analyze() + || validate_clac_rule(column, &properties)? + { Ok(Some(Arc::clone(column))) } else { Ok(None) @@ -380,9 +388,33 @@ pub async fn transform_sql_with_ctx( register_remote_function(ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; - let ctx = - create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), properties, false).await?; - let plan = ctx.state().create_logical_plan(sql).await?; + let ctx = create_ctx_with_mdl( + ctx, + Arc::clone(&analyzed_mdl), + Arc::clone(&properties), + Mode::Unparse, + ) + .await?; + let plan = match ctx.state().create_logical_plan(sql).await { + Ok(plan) => plan, + Err(e) => { + match permission_analyze( + analyzed_mdl.wren_mdl().manifest.clone(), + sql, + remote_functions, + properties, + ) + .await + { + Ok(_) => { + return Err(e); + } + Err(e) => { + return Err(e); + } + } + } + }; debug!("wren-core original plan:\n {plan}"); let analyzed = ctx.state().optimize(&plan)?; debug!("wren-core final planned:\n {analyzed}"); @@ -404,6 +436,58 @@ pub async fn transform_sql_with_ctx( } } +/// Try to check if the fail reason is a permission denied error. +/// +/// In a normal exeuction flow, if a column is not allowed to be used in the model plan, +/// it will return an column not found error because the column won't be registered in the [WrenDataSource]. +/// Through this function, we can check if the error is a permission denied error, then provide a more user-friendly error message. +async fn permission_analyze( + manifest: Manifest, + sql: &str, + remote_functions: &[RemoteFunction], + properties: SessionPropertiesRef, +) -> Result<()> { + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::clone(&properties), + Mode::PermissionAnalyze, + )?); + let ctx = SessionContext::new(); + remote_functions.iter().try_for_each(|remote_function| { + debug!("Registering remote function: {remote_function:?}"); + register_remote_function(&ctx, remote_function)?; + Ok::<_, DataFusionError>(()) + })?; + let ctx = + create_ctx_with_mdl(&ctx, analyzed_mdl, properties, Mode::PermissionAnalyze) + .await?; + + let plan = match ctx.state().create_logical_plan(sql).await { + Ok(plan) => plan, + Err(e) => { + debug!("Failed to create logical plan: {e}"); + return Ok(()); + } + }; + debug!("wren-core start to anlayze:\n {plan}"); + match ctx.state().optimize(&plan) { + Ok(_) => { + info!("SQL is allowed to be planned"); + } + // If the error is a permission denied error, we throw it instead. Otherwise, we throw the original error. + Err(e) => { + if let DataFusionError::Context(_, ee) = &e { + if let DataFusionError::External(we) = ee.as_ref() { + if we.downcast_ref::().is_some() { + return Err(e); + } + } + } + } + } + Ok(()) +} + fn register_remote_function( ctx: &SessionContext, remote_function: &RemoteFunction, @@ -459,7 +543,7 @@ mod test { use std::sync::Arc; use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; - use crate::mdl::context::create_ctx_with_mdl; + use crate::mdl::context::{create_ctx_with_mdl, Mode}; use crate::mdl::function::RemoteFunction; use crate::mdl::manifest::DataSource::MySQL; use crate::mdl::manifest::Manifest; @@ -489,8 +573,11 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let _ = mdl::transform_sql( Arc::clone(&analyzed_mdl), &[], @@ -511,8 +598,11 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let tests: Vec<&str> = vec![ "select o_orderkey + o_orderkey from test.test.orders", @@ -555,8 +645,11 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {e}"), }; - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let sql = "select * from test.test.customer_view"; println!("Original: {sql}"); let _ = transform_sql_with_ctx( @@ -585,8 +678,11 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {e}"), }; - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let sql = "select totalcost from profile"; let result = transform_sql_with_ctx( &SessionContext::new(), @@ -630,6 +726,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select * from "CTest"."STest"."Customer""#; let actual = mdl::transform_sql_with_ctx( @@ -674,6 +771,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let actual = transform_sql_with_ctx( &ctx, @@ -747,6 +845,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( @@ -818,6 +917,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select name_append from wren.test.artist"#; let _ = transform_sql_with_ctx( @@ -876,6 +976,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select "串接名字" from wren.test.artist"#; let actual = transform_sql_with_ctx( @@ -951,6 +1052,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select current_date > "出道時間" from wren.test.artist"#; let actual = transform_sql_with_ctx( @@ -1057,6 +1159,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select * from unnest([1, 2, 3])"; let actual = transform_sql_with_ctx( @@ -1075,6 +1178,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select * from unnest([1, 2, 3])"; let actual = transform_sql_with_ctx( @@ -1209,6 +1313,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#; let actual = transform_sql_with_ctx( @@ -1260,9 +1365,13 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, registers)?); let properties_ref = Arc::new(HashMap::new()); - let ctx = - create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), properties_ref, true) - .await?; + let ctx = create_ctx_with_mdl( + &ctx, + Arc::clone(&analyzed_mdl), + properties_ref, + Mode::LocalRuntime, + ) + .await?; let sql = r#"select arrow_typeof(timestamp_col), arrow_typeof(timestamptz_col) from wren.test.timestamp_table limit 1"#; let result = ctx.sql(sql).await?.collect().await?; assert_snapshot!(batches_to_string(&result), @r#" @@ -1300,6 +1409,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = r#"select timestamp_col = timestamptz_col from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( @@ -1377,6 +1487,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select list_col[1] from wren.test.list_table"; let actual = transform_sql_with_ctx( @@ -1421,6 +1532,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select struct_col.float_field from wren.test.struct_table"; let actual = transform_sql_with_ctx( @@ -1476,6 +1588,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select struct_col.float_field from wren.test.struct_table"; let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) @@ -1539,6 +1652,7 @@ mod test { let mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let ctx = SessionContext::new(); let sql = "SELECT trim(' abc')"; @@ -1572,6 +1686,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let result = transform_sql_with_ctx( &ctx, @@ -1609,6 +1724,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let result = transform_sql_with_ctx( &ctx, @@ -1645,6 +1761,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let result = transform_sql_with_ctx( &ctx, @@ -1697,6 +1814,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let result = transform_sql_with_ctx( &ctx, @@ -1749,6 +1867,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let result = transform_sql_with_ctx( &ctx, @@ -1791,6 +1910,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer"; let headers = @@ -1854,6 +1974,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer"; let headers = Arc::new(build_headers(&[ @@ -1932,6 +2053,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer"; @@ -2003,6 +2125,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer"; let headers = Arc::new(build_headers(&[( @@ -2039,6 +2162,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_nation".to_string(), @@ -2080,6 +2204,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_nation".to_string(), @@ -2135,6 +2260,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_nation".to_string(), @@ -2215,6 +2341,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_user".to_string(), @@ -2301,6 +2428,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_nation".to_string(), @@ -2368,6 +2496,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let headers = Arc::new(build_headers(&[( "session_nation".to_string(), @@ -2422,8 +2551,11 @@ mod test { "session_level".to_string(), Some("1".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); let sql = "SELECT * FROM customer"; assert_snapshot!( @@ -2435,15 +2567,18 @@ mod test { "session_level".to_string(), Some("0".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" ); let headers = Arc::new(HashMap::default()); - match AnalyzedWrenMDL::analyze(manifest, headers.clone()) { + match AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone(), Mode::Unparse) { Err(e) => { assert_snapshot!( e.to_string(), @@ -2453,6 +2588,35 @@ mod test { _ => panic!("Expected error"), } + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("0".to_string()), + )])); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); + let sql = "SELECT c_name FROM customer"; + + match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await + { + Err(e) => { + assert_snapshot!( + e.to_string(), + @r#" + ModelAnalyzeRule + caused by + External error: Permission Denied: No permission to access "customer"."c_name" + "# + ) + } + Ok(sql) => { + panic!("Expected error, but got SQL: {sql}"); + } + } + Ok(()) } @@ -2487,8 +2651,11 @@ mod test { "session_level".to_string(), Some("1".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); let sql = "SELECT * FROM customer"; assert_snapshot!( @@ -2500,8 +2667,11 @@ mod test { "session_level".to_string(), Some("0".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" @@ -2509,8 +2679,11 @@ mod test { // test the rule is applied the default value if the optional property is None let headers = Arc::new(HashMap::default()); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" @@ -2543,8 +2716,11 @@ mod test { // test the rule is skipped when the optional property is None let headers = Arc::new(HashMap::default()); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" @@ -2585,8 +2761,11 @@ mod test { "session_level".to_string(), Some("1".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); let sql = "SELECT c_name_upper FROM customer"; assert_snapshot!( @@ -2598,8 +2777,11 @@ mod test { "session_level".to_string(), Some("0".to_string()), )])); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest.clone(), + headers.clone(), + Mode::Unparse, + )?); match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await @@ -2643,6 +2825,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer"; let headers = Arc::new(build_headers(&[( @@ -2675,6 +2858,7 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "SELECT * FROM customer limit 0"; let headers = Arc::new(HashMap::default()); diff --git a/wren-core/core/src/mdl/utils.rs b/wren-core/core/src/mdl/utils.rs index 770a8c842..ebc54fa47 100644 --- a/wren-core/core/src/mdl/utils.rs +++ b/wren-core/core/src/mdl/utils.rs @@ -267,6 +267,7 @@ mod tests { use datafusion::prelude::SessionContext; use crate::logical_plan::utils::from_qualified_name; + use crate::mdl::context::Mode; use crate::mdl::manifest::Manifest; use crate::mdl::AnalyzedWrenMDL; @@ -278,8 +279,11 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl @@ -307,8 +311,11 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl @@ -357,8 +364,11 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let ctx = SessionContext::new(); let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); let expr = super::create_wren_expr_for_model( @@ -378,8 +388,11 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + mdl, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); let ctx = SessionContext::new(); let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); let expr = super::create_remote_expr_for_model( diff --git a/wren-core/sqllogictest/src/test_context.rs b/wren-core/sqllogictest/src/test_context.rs index 8cf159da0..fa80c6656 100644 --- a/wren-core/sqllogictest/src/test_context.rs +++ b/wren-core/sqllogictest/src/test_context.rs @@ -28,7 +28,7 @@ use tempfile::TempDir; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder, }; -use wren_core::mdl::context::create_ctx_with_mdl; +use wren_core::mdl::context::{create_ctx_with_mdl, Mode}; use wren_core::mdl::manifest::JoinType; use wren_core::mdl::AnalyzedWrenMDL; @@ -305,7 +305,7 @@ async fn register_ecommerce_mdl( ctx, Arc::clone(&analyzed_mdl), Arc::new(HashMap::new()), - true, + Mode::LocalRuntime, ) .await?; Ok((ctx.to_owned(), analyzed_mdl)) @@ -541,7 +541,7 @@ async fn register_tpch_mdl( ctx, Arc::clone(&analyzed_mdl), Arc::new(HashMap::new()), - true, + Mode::LocalRuntime, ) .await?; Ok((ctx.to_owned(), analyzed_mdl)) diff --git a/wren-core/wren-example/examples/plan-sql.rs b/wren-core/wren-example/examples/plan-sql.rs index 82de858cd..130ed9693 100644 --- a/wren-core/wren-example/examples/plan-sql.rs +++ b/wren-core/wren-example/examples/plan-sql.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, }; +use wren_core::mdl::context::Mode; use wren_core::mdl::manifest::{JoinType, Manifest}; use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; @@ -13,6 +14,7 @@ async fn main() -> datafusion::common::Result<()> { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select customer_state from wrenai.public.orders_model"; diff --git a/wren-core/wren-example/examples/to-many-calculation.rs b/wren-core/wren-example/examples/to-many-calculation.rs index db1d9a84a..60f2a3fd1 100644 --- a/wren-core/wren-example/examples/to-many-calculation.rs +++ b/wren-core/wren-example/examples/to-many-calculation.rs @@ -7,7 +7,7 @@ use datafusion::prelude::{CsvReadOptions, SessionContext}; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, }; -use wren_core::mdl::context::create_ctx_with_mdl; +use wren_core::mdl::context::{create_ctx_with_mdl, Mode}; use wren_core::mdl::manifest::{JoinType, Manifest}; use wren_core::mdl::AnalyzedWrenMDL; @@ -76,8 +76,13 @@ async fn main() -> Result<()> { ]); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); - let ctx = - create_ctx_with_mdl(&ctx, analyzed_mdl, Arc::new(HashMap::new()), true).await?; + let ctx = create_ctx_with_mdl( + &ctx, + analyzed_mdl, + Arc::new(HashMap::new()), + Mode::LocalRuntime, + ) + .await?; let df = match ctx .sql("select totalprice from wrenai.public.customers") .await diff --git a/wren-core/wren-example/examples/view.rs b/wren-core/wren-example/examples/view.rs index a2afd061a..e7515d2f2 100644 --- a/wren-core/wren-example/examples/view.rs +++ b/wren-core/wren-example/examples/view.rs @@ -6,6 +6,7 @@ use datafusion::prelude::SessionContext; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, ViewBuilder, }; +use wren_core::mdl::context::Mode; use wren_core::mdl::manifest::Manifest; use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; @@ -15,6 +16,7 @@ async fn main() -> datafusion::common::Result<()> { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( manifest, Arc::new(HashMap::default()), + Mode::Unparse, )?); let sql = "select * from wrenai.public.customers_view";