Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,11 @@ impl PySessionContext {
};
let manifest = to_manifest(mdl_base64)?;
let properties_ref = Arc::new(properties_map);
let Ok(analyzed_mdl) =
AnalyzedWrenMDL::analyze(manifest, Arc::clone(&properties_ref))
else {
let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(
manifest,
Arc::clone(&properties_ref),
mdl::context::Mode::Unparse,
) else {
return Err(CoreError::new("Failed to analyze manifest").into());
};

Expand All @@ -172,7 +174,7 @@ impl PySessionContext {
&ctx,
Arc::clone(&analyzed_mdl),
Arc::new(HashMap::new()),
false,
mdl::context::Mode::Unparse,
))
.map_err(CoreError::from)?;

Expand Down
13 changes: 11 additions & 2 deletions wren-core-py/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use pyo3::PyErr;
use std::num::ParseIntError;
use std::string::FromUtf8Error;
use thiserror::Error;
use wren_core::DataFusionError;
use wren_core::WrenError;

#[derive(Error, Debug, PartialEq)]
#[error("{message}")]
Expand Down Expand Up @@ -49,8 +51,15 @@ impl From<serde_json::Error> for CoreError {
}
}

impl From<wren_core::DataFusionError> for CoreError {
fn from(err: wren_core::DataFusionError) -> Self {
impl From<DataFusionError> for CoreError {
fn from(err: DataFusionError) -> Self {
if let DataFusionError::Context(_, ee) = &err {
if let DataFusionError::External(we) = ee.as_ref() {
if let Some(we) = we.downcast_ref::<WrenError>() {
return CoreError::new(we.to_string().as_str());
}
}
}
CoreError::new(err.to_string().as_str())
}
}
Expand Down
10 changes: 10 additions & 0 deletions wren-core-py/tests/test_modeling_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,13 @@ def test_clac():
rewritten_sql
== "SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM main.customer AS __source) AS customer) AS customer"
)

session_context = SessionContext(manifest_str, None, properties_hashable)
sql = "SELECT c_name FROM my_catalog.my_schema.customer"
try:
session_context.transform_sql(sql)
except Exception as e:
assert (
str(e)
== "Permission Denied: No permission to access \"customer\".\"c_name\""
)
2 changes: 2 additions & 0 deletions wren-core/benchmarks/src/tpch/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use structopt::StructOpt;
use wren_core::mdl::context::Mode;
use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL};

#[derive(Debug, StructOpt, Clone)]
Expand Down Expand Up @@ -53,6 +54,7 @@ impl RunOpt {
let mdl = Arc::new(AnalyzedWrenMDL::analyze(
tpch_manifest(),
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let mut millis = vec![];
// run benchmark
Expand Down
2 changes: 2 additions & 0 deletions wren-core/benchmarks/src/wren/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use structopt::StructOpt;
use wren_core::mdl::context::Mode;
use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL};

#[derive(Debug, StructOpt, Clone)]
Expand Down Expand Up @@ -63,6 +64,7 @@ impl RunOpt {
let mdl = Arc::new(AnalyzedWrenMDL::analyze(
get_manifest(query_id)?,
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let start = Instant::now();
let sql = &get_query_sql(query_id)?;
Expand Down
1 change: 1 addition & 0 deletions wren-core/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ pub use datafusion::error::DataFusionError;
pub use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
pub use datafusion::prelude::*;
pub use datafusion::sql::sqlparser::*;
pub use logical_plan::error::WrenError;
pub use mdl::AnalyzedWrenMDL;
15 changes: 15 additions & 0 deletions wren-core/core/src/logical_plan/analyze/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ use datafusion::logical_expr::{
use log::debug;
use petgraph::Graph;

use crate::logical_plan::analyze::access_control::validate_clac_rule;
use crate::logical_plan::analyze::RelationChain;
use crate::logical_plan::analyze::RelationChain::Start;
use crate::logical_plan::error::WrenError;
use crate::logical_plan::utils::{from_qualified_name, try_map_data_type};
use crate::mdl;
use crate::mdl::context::SessionPropertiesRef;
Expand Down Expand Up @@ -146,6 +148,19 @@ impl ModelPlanNodeBuilder {
.any(|expr| is_required_column(expr, column.name()))
});
for column in required_columns {
// Actually, it's only be checked in PermissionAnalyze mode.
// In Unparse or LocalRuntime mode, an invalid column won't be registered in the table provider.
// A column accessing will be failed by the column not found error.
if !validate_clac_rule(&column, &self.properties)? {
return Err(DataFusionError::External(Box::new(
WrenError::PermissionDenied(format!(
r#"No permission to access "{}"."{}""#,
model.name(),
column.name
)),
)));
}

if column.is_calculated {
let expr = if column.expression.is_some() {
let column_rf = self
Expand Down
16 changes: 16 additions & 0 deletions wren-core/core/src/logical_plan/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use std::{error::Error, fmt::Display};

#[derive(Debug, Clone)]
pub enum WrenError {
PermissionDenied(String),
}

impl Error for WrenError {}

impl Display for WrenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WrenError::PermissionDenied(msg) => write!(f, "Permission Denied: {msg}"),
}
}
}
1 change: 1 addition & 0 deletions wren-core/core/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod analyze;
pub mod error;
pub mod optimize;
pub mod utils;
111 changes: 93 additions & 18 deletions wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub async fn create_ctx_with_mdl(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
properties: SessionPropertiesRef,
is_local_runtime: bool,
mode: Mode,
) -> Result<SessionContext> {
let session_timezone = properties
.get("x-wren-timezone")
Expand Down Expand Up @@ -88,29 +88,77 @@ pub async fn create_ctx_with_mdl(
.collect::<HashMap<_, _>>(),
);

let new_state = if is_local_runtime {
new_state.with_analyzer_rules(analyze_rule_for_local_runtime(
Arc::clone(&analyzed_mdl),
reset_default_catalog_schema.clone(),
Arc::clone(&properties),
))
// The plan will be executed locally, so apply the default optimizer rules
let new_state = new_state.with_analyzer_rules(mode.get_analyze_rules(
Arc::clone(&analyzed_mdl),
Arc::clone(&reset_default_catalog_schema),
Arc::clone(&properties),
));
let new_state = if let Some(optimize_rules) = mode.get_optimize_rules() {
new_state.with_optimizer_rules(optimize_rules)
} else {
new_state
.with_analyzer_rules(analyze_rule_for_unparsing(
Arc::clone(&analyzed_mdl),
reset_default_catalog_schema.clone(),
Arc::clone(&properties),
))
.with_optimizer_rules(optimize_rule_for_unparsing())
};

let new_state = new_state.with_config(config).build();
let ctx = SessionContext::new_with_state(new_state);
register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl(), properties).await?;
register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl(), properties, mode).await?;
Ok(ctx)
}

/// Execution mode for Wren engine.
#[derive(Debug)]
pub enum Mode {
/// Local runtime mode, used for executing queries by DataFusion directly.
LocalRuntime,
/// Unparse mode, used for generating SQL statements.
/// This mode is used to generate SQL statements that can be executed in other SQL engines.
Unparse,
/// Permission analyze mode, used for analyzing if the error is caused by permission denied.
/// It's only be used when an error is raised during Unparse mode.
PermissionAnalyze,
}

impl Mode {
pub fn get_analyze_rules(
&self,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
session_state_ref: SessionStateRef,
properties: SessionPropertiesRef,
) -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
match self {
Mode::LocalRuntime => analyze_rule_for_local_runtime(
Arc::clone(&analyzed_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
),
Mode::Unparse => analyze_rule_for_unparsing(
Arc::clone(&analyzed_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
),
Mode::PermissionAnalyze => analyze_rule_for_permission(
Arc::clone(&analyzed_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
),
}
}

pub fn get_optimize_rules(
&self,
) -> Option<Vec<Arc<dyn OptimizerRule + Send + Sync>>> {
match self {
Mode::LocalRuntime => None,
Mode::Unparse => Some(optimize_rule_for_unparsing()),
Mode::PermissionAnalyze => Some(vec![]),
}
}

pub fn is_permission_analyze(&self) -> bool {
matches!(self, Mode::PermissionAnalyze)
}
}

// Analyzer rules for local runtime
fn analyze_rule_for_local_runtime(
analyzed_mdl: Arc<AnalyzedWrenMDL>,
Expand Down Expand Up @@ -227,10 +275,32 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
]
}

fn analyze_rule_for_permission(
analyzed_mdl: Arc<AnalyzedWrenMDL>,
session_state_ref: SessionStateRef,
properties: SessionPropertiesRef,
) -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
vec![
// To align the lastest change in datafusion, apply this this rule first.
Arc::new(ExpandWildcardRule::new()),
// expand the view should be the first rule
Arc::new(ExpandWrenViewRule::new(
Arc::clone(&analyzed_mdl),
Arc::clone(&session_state_ref),
)),
Arc::new(ModelAnalyzeRule::new(
Arc::clone(&analyzed_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
)),
]
}

pub async fn register_table_with_mdl(
ctx: &SessionContext,
wren_mdl: Arc<WrenMDL>,
properties: SessionPropertiesRef,
mode: Mode,
) -> Result<()> {
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
Expand All @@ -239,7 +309,7 @@ pub async fn register_table_with_mdl(
ctx.register_catalog(&wren_mdl.manifest.catalog, Arc::new(catalog));

for model in wren_mdl.manifest.models.iter() {
let table = WrenDataSource::new(Arc::clone(model), &properties)?;
let table = WrenDataSource::new(Arc::clone(model), &properties, &mode)?;
ctx.register_table(
TableReference::full(wren_mdl.catalog(), wren_mdl.schema(), model.name()),
Arc::new(table),
Expand All @@ -262,12 +332,17 @@ pub struct WrenDataSource {
}

impl WrenDataSource {
pub fn new(model: Arc<Model>, properties: &SessionPropertiesRef) -> Result<Self> {
pub fn new(
model: Arc<Model>,
properties: &SessionPropertiesRef,
mode: &Mode,
) -> Result<Self> {
let available_columns = model
.get_physical_columns()
.iter()
.map(|column| {
if validate_clac_rule(column, properties)? {
if mode.is_permission_analyze() || validate_clac_rule(column, properties)?
{
Ok(Some(Arc::clone(column)))
} else {
Ok(None)
Expand Down
Loading