Skip to content

Commit

Permalink
Allow invoking a calculation that belongs to another model (#676)
Browse files Browse the repository at this point in the history
* refactor the planner

* tmp

* introduce a calculation that invokes another calculation

* remove all dbg!

* add example

* fix fmt and clippy

* enhance the example, todo case and tests
  • Loading branch information
goldmedal committed Jul 18, 2024
1 parent 3cbcde9 commit 16cc917
Show file tree
Hide file tree
Showing 10 changed files with 1,000 additions and 356 deletions.
2 changes: 1 addition & 1 deletion wren-modeling-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ version = "0.1.0"

[workspace.dependencies]
async-trait = "0.1.80"
datafusion = { version = "39.0.0" }
datafusion = { version = "39.0.0", features = ["backtrace"] }
env_logger = "0.11.3"
log = { version = "0.4.14" }
petgraph = "0.6.5"
Expand Down
3 changes: 3 additions & 0 deletions wren-modeling-rs/core/src/logical_plan/analyze/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
pub mod plan;
mod relation_chain;
pub mod rule;

pub use relation_chain::RelationChain;
702 changes: 369 additions & 333 deletions wren-modeling-rs/core/src/logical_plan/analyze/plan.rs

Large diffs are not rendered by default.

250 changes: 250 additions & 0 deletions wren-modeling-rs/core/src/logical_plan/analyze/relation_chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
use crate::logical_plan::analyze::plan::{
CalculationPlanNode, ModelPlanNode, ModelSourceNode, OrdExpr, PartialModelPlanNode,
};
use crate::logical_plan::analyze::relation_chain::RelationChain::Start;
use crate::logical_plan::analyze::rule::ModelGenerationRule;
use crate::logical_plan::utils::create_schema;
use crate::mdl;
use crate::mdl::lineage::DatasetLink;
use crate::mdl::manifest::JoinType;
use crate::mdl::{AnalyzedWrenMDL, Dataset};
use datafusion::catalog::TableReference;
use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{
col, Expr, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNodeCore,
};
use petgraph::graph::NodeIndex;
use petgraph::Graph;
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

/// RelationChain is a chain of models that are connected by relationships.
/// The chain is used to generate the join plan for the model.
///
/// `with_chain` wraps the existing chain in a new `Chain` link for every model it
/// appends, so the most recently added model is the outermost value and the
/// source model is the innermost `Start`. The layout looks like:
/// Chain(Model3, .., Chain(Model2, .., Start(Model1)))
#[derive(Eq, PartialEq, Debug, Hash, Clone)]
pub enum RelationChain {
/// One link of the chain: the plan of the model being joined, the join type and
/// the join condition copied from the relationship link, and the rest of the chain.
Chain(LogicalPlan, JoinType, String, Box<RelationChain>),
/// The head of the chain, holding the plan of the source model.
Start(LogicalPlan),
}

impl RelationChain {
/// Creates the head (`Start`) of a relation chain from `dataset`.
///
/// The dataset must be a model; it is wrapped in a `ModelSourceNode` extension
/// plan that exposes `required_fields`.
///
/// # Errors
///
/// Returns a not-implemented error for any dataset that is not a model, and
/// propagates any error raised while constructing the `ModelSourceNode`.
pub(crate) fn source(
    dataset: &Dataset,
    required_fields: Vec<Expr>,
    analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
) -> datafusion::common::Result<Self> {
    // Guard clause: only models can start a chain.
    let Dataset::Model(source_model) = dataset else {
        return not_impl_err!("Only support model as source dataset");
    };

    let source_node = ModelSourceNode::new(
        Arc::clone(source_model),
        required_fields,
        analyzed_wren_mdl,
        None,
    )?;

    Ok(Start(LogicalPlan::Extension(Extension {
        node: Arc::new(source_node),
    })))
}

pub fn with_chain(
source: Self,
mut start: NodeIndex,
iter: impl Iterator<Item = NodeIndex>,
directed_graph: Graph<Dataset, DatasetLink>,
model_required_fields: &HashMap<TableReference, BTreeSet<OrdExpr>>,
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
) -> datafusion::common::Result<Self> {
let mut relation_chain = source;

for next in iter {
let target = directed_graph.node_weight(next).unwrap();
let Some(link_index) = directed_graph.find_edge(start, next) else {
break;
};
let link = directed_graph.edge_weight(link_index).unwrap();
let target_ref = TableReference::full(
analyzed_wren_mdl.wren_mdl().catalog(),
analyzed_wren_mdl.wren_mdl().schema(),
target.name(),
);
let Some(fields) = model_required_fields.get(&target_ref) else {
return plan_err!("Required fields not found for {}", target_ref);
};
match target {
Dataset::Model(target_model) => {
let node = if fields.iter().any(|e| {
e.column.is_some() && e.column.clone().unwrap().is_calculated
}) {
let schema = create_schema(
fields.iter().filter_map(|e| e.column.clone()).collect(),
)?;
let plan = ModelPlanNode::new(
Arc::clone(target_model),
fields.iter().cloned().map(|c| c.expr).collect(),
None,
Arc::clone(&analyzed_wren_mdl),
)?;

let df_schema =
DFSchemaRef::from(DFSchema::try_from(schema).unwrap());
LogicalPlan::Extension(Extension {
node: Arc::new(PartialModelPlanNode::new(plan, df_schema)),
})
} else {
LogicalPlan::Extension(Extension {
node: Arc::new(ModelSourceNode::new(
Arc::clone(target_model),
fields.iter().cloned().map(|c| c.expr).collect(),
Arc::clone(&analyzed_wren_mdl),
None,
)?),
})
};
relation_chain = RelationChain::Chain(
node,
link.join_type,
link.condition.clone(),
Box::new(relation_chain),
);
}
_ => return plan_err!("Only support model as source dataset"),
}
start = next;
}
Ok(relation_chain)
}

pub(crate) fn plan(
&mut self,
rule: ModelGenerationRule,
) -> datafusion::common::Result<Option<LogicalPlan>> {
match self {
RelationChain::Chain(plan, _, condition, ref mut next) => {
let left = rule.generate_model_internal(plan.clone())?.data;
let join_keys: Vec<Expr> = mdl::utils::collect_identifiers(condition)?
.iter()
.cloned()
.map(|c| col(c.flat_name()))
.collect();
let join_condition = join_keys[0].clone().eq(join_keys[1].clone());
let Some(right) = next.plan(rule)? else {
return plan_err!("Nil relation chain");
};
let mut required_exprs = BTreeSet::new();
// collect the output calculated fields
match plan {
LogicalPlan::Extension(plan) => {
if let Some(model_plan) =
plan.node.as_any().downcast_ref::<ModelPlanNode>()
{
UserDefinedLogicalNodeCore::schema(model_plan)
.fields()
.iter()
.map(|field| {
col(format!(
"{}.{}",
model_plan.plan_name,
field.name()
))
})
.for_each(|c| {
required_exprs.insert(OrdExpr::new(c));
});
} else if let Some(model_source_plan) =
plan.node.as_any().downcast_ref::<ModelSourceNode>()
{
UserDefinedLogicalNodeCore::schema(model_source_plan)
.fields()
.iter()
.map(|field| {
col(format!(
"{}.{}",
model_source_plan.model_name,
field.name()
))
})
.for_each(|c| {
required_exprs.insert(OrdExpr::new(c));
});
} else if let Some(calculation_plan) =
plan.node.as_any().downcast_ref::<CalculationPlanNode>()
{
UserDefinedLogicalNodeCore::schema(calculation_plan)
.fields()
.iter()
.map(|field| {
col(format!(
"{}.{}",
calculation_plan.calculation.column.name(),
field.name()
))
})
.for_each(|c| {
required_exprs.insert(OrdExpr::new(c));
});
} else if let Some(partial_model_plan) =
plan.node.as_any().downcast_ref::<PartialModelPlanNode>()
{
UserDefinedLogicalNodeCore::schema(partial_model_plan)
.fields()
.iter()
.map(|field| {
col(format!(
"{}.{}",
partial_model_plan.model_node.plan_name,
field.name()
))
})
.for_each(|c| {
required_exprs.insert(OrdExpr::new(c));
});
} else {
return plan_err!("Invalid extension plan node");
}
}
_ => return internal_err!("Invalid plan node"),
};
// collect the column of the left table
for index in 0..left.schema().fields().len() {
let (Some(table_rf), f) = left.schema().qualified_field(index) else {
return plan_err!("Field not found");
};
let qualified_name = format!("{}.{}", table_rf, f.name());
required_exprs.insert(OrdExpr::new(col(qualified_name)));
}

// collect the column of the right table
for index in 0..right.schema().fields().len() {
let (Some(table_rf), f) = right.schema().qualified_field(index)
else {
return plan_err!("Field not found");
};
let qualified_name = format!("{}.{}", table_rf, f.name());
required_exprs.insert(OrdExpr::new(col(qualified_name)));
}

let required_field: Vec<Expr> = required_exprs
.iter()
.map(|expr| expr.expr.clone())
.collect();

Ok(Some(
LogicalPlanBuilder::from(left)
.join_on(
right,
datafusion::logical_expr::JoinType::Right,
vec![join_condition],
)?
.project(required_field)?
.build()?,
))
}
Start(plan) => Ok(Some(rule.generate_model_internal(plan.clone())?.data)),
}
}
}
36 changes: 25 additions & 11 deletions wren-modeling-rs/core/src/logical_plan/analyze/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::logical_plan::tree_node::unwrap_arc;
use datafusion::logical_expr::{col, ident, utils, Extension};
use datafusion::logical_expr::{
col, ident, utils, Extension, UserDefinedLogicalNodeCore,
};
use datafusion::logical_expr::{Expr, Join, LogicalPlan, LogicalPlanBuilder, TableScan};
use datafusion::optimizer::analyzer::AnalyzerRule;
use datafusion::sql::TableReference;

use crate::logical_plan::analyze::plan::{
CalculationPlanNode, ModelPlanNode, ModelSourceNode,
CalculationPlanNode, ModelPlanNode, ModelSourceNode, PartialModelPlanNode,
};
use crate::logical_plan::utils::create_remote_table_source;
use crate::mdl::manifest::Model;
Expand Down Expand Up @@ -225,14 +227,6 @@ impl ModelGenerationRule {
let source_plan = model_plan.relation_chain.clone().plan(
ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
)?;

let model: Arc<Model> = Arc::clone(
&self
.analyzed_wren_mdl
.wren_mdl()
.get_model(&model_plan.model_name)
.expect("Model not found"),
);
let result = match source_plan {
Some(plan) => LogicalPlanBuilder::from(plan)
.project(model_plan.required_exprs.clone())?
Expand All @@ -244,7 +238,7 @@ impl ModelGenerationRule {
// calculated field scope

let alias = LogicalPlanBuilder::from(result)
.alias(model.name.clone())?
.alias(&model_plan.plan_name)?
.build()?;
Ok(Transformed::yes(alias))
} else if let Some(model_plan) =
Expand Down Expand Up @@ -328,6 +322,26 @@ impl ModelGenerationRule {
} else {
return plan_err!("measures should have an alias");
}
} else if let Some(partial_model) = extension
.node
.as_any()
.downcast_ref::<PartialModelPlanNode>(
) {
let plan = LogicalPlan::Extension(Extension {
node: Arc::new(partial_model.model_node.clone()),
});
let source_plan = self.generate_model_internal(plan)?.data;
let projection: Vec<_> = partial_model
.schema()
.fields()
.iter()
.map(|f| col(datafusion::common::Column::from((None, f))))
.collect();
let alias = LogicalPlanBuilder::from(source_plan)
.project(projection)?
.alias(partial_model.model_node.plan_name.clone())?
.build()?;
Ok(Transformed::yes(alias))
} else {
Ok(Transformed::no(LogicalPlan::Extension(extension)))
}
Expand Down
Loading

0 comments on commit 16cc917

Please sign in to comment.