Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
with:
rust-version: stable
- name: Run tests (excluding doctests)
run: cargo test --lib --tests --bins
run: RUST_MIN_STACK=8388608 cargo test --lib --tests --bins
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid the following test error, increase Rust's minimum thread stack size via RUST_MIN_STACK:

thread 'mdl::test::test_plan_calculation_without_unnamed_subquery' has overflowed its stack.

- name: Verify Working Directory Clean
run: git diff --exit-code

Expand All @@ -77,7 +77,7 @@ jobs:
- name: Run tests (excluding doctests)
shell: bash
run: |
cargo test --lib --tests --bins
RUST_MIN_STACK=8388608 cargo test --lib --tests --bins

macos:
name: cargo test (macos)
Expand All @@ -90,7 +90,7 @@ jobs:
uses: ./.github/actions/rust/setup-macos-builder
- name: Run tests (excluding doctests)
shell: bash
run: cargo test --lib --tests --bins
run: RUST_MIN_STACK=8388608 cargo test --lib --tests --bins

macos-aarch64:
name: cargo test (macos-aarch64)
Expand All @@ -103,7 +103,7 @@ jobs:
uses: ./.github/actions/rust/setup-macos-aarch64-builder
- name: Run tests (excluding doctests)
shell: bash
run: cargo test --lib --tests --bins
run: RUST_MIN_STACK=8388608 cargo test --lib --tests --bins

check-fmt:
name: Check cargo fmt
Expand Down
83 changes: 69 additions & 14 deletions wren-core/core/src/logical_plan/analyze/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use std::sync::Arc;

use datafusion::arrow::datatypes::Field;
use datafusion::common::{
internal_err, plan_err, Column, DFSchema, DFSchemaRef, TableReference,
internal_datafusion_err, internal_err, plan_err, Column, DFSchema, DFSchemaRef,
TableReference,
};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::expr::WildcardOptions;
Expand Down Expand Up @@ -198,21 +199,29 @@ impl ModelPlanNodeBuilder {
} else {
merge_graph(&mut self.directed_graph, column_graph)?;
if self.is_contain_calculation_source(&qualified_column) {
collect_partial_model_plan(
collect_partial_model_plan_for_calculation(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
&qualified_column,
&mut self.model_required_fields,
)?;
}
// Collect the column for building the partial model for the related model.
collect_partial_model_required_fields(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
&qualified_column,
&mut self.model_required_fields,
)?;
self.required_exprs_buffer
.insert(OrdExpr::new(expr.clone()));
let _ = collect_model_required_fields(
&qualified_column,
// Collect the column for building the source model
collect_model_required_fields(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
&qualified_column,
&mut self.model_required_fields,
);
)?;
}
} else {
let expr_plan = get_remote_column_exp(
Expand Down Expand Up @@ -294,7 +303,6 @@ impl ModelPlanNodeBuilder {
.get(&model_ref)
.map(|c| c.iter().cloned().map(|c| c.expr).collect())
.unwrap_or_default();

let mut calculate_iter = self.required_calculation.iter();
let source_chain =
if !source_required_fields.is_empty() || required_fields.is_empty() {
Expand Down Expand Up @@ -438,18 +446,18 @@ impl ModelPlanNodeBuilder {
let mut partial_model_required_fields = HashMap::new();

if self.is_contain_calculation_source(qualified_column) {
collect_partial_model_plan(
collect_partial_model_plan_for_calculation(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
qualified_column,
&mut partial_model_required_fields,
)?;
}

collect_model_required_fields(
qualified_column,
collect_partial_model_required_fields(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
qualified_column,
&mut partial_model_required_fields,
)?;

Expand Down Expand Up @@ -505,7 +513,9 @@ fn is_required_column(expr: &Expr, name: &str) -> bool {
}
}

fn collect_partial_model_plan(
/// Collect the fields for the calculation plan.
/// It collects only the calculated fields for the calculation plan.
fn collect_partial_model_plan_for_calculation(
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
session_state_ref: SessionStateRef,
qualified_column: &Column,
Expand Down Expand Up @@ -547,11 +557,56 @@ fn collect_partial_model_plan(
Ok(())
}

fn collect_model_required_fields(
/// Collect the required fields for the partial model used by another model through the relationship.
/// Only the non-calculated fields are collected here; calculated fields are
/// gathered separately for the calculation plan.
fn collect_partial_model_required_fields(
    analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
    session_state_ref: SessionStateRef,
    qualified_column: &Column,
    required_fields: &mut HashMap<TableReference, BTreeSet<OrdExpr>>,
) -> Result<()> {
    // Look up every source column that the qualified column depends on.
    let Some(source_columns) = analyzed_wren_mdl
        .lineage()
        .required_fields_map
        .get(qualified_column)
    else {
        return plan_err!("Required fields not found for {}", qualified_column);
    };

    for source_column in source_columns {
        let Some(relation) = &source_column.relation else {
            return plan_err!("Source dataset not found for {}", source_column);
        };
        let Some(ColumnReference { dataset, column }) = analyzed_wren_mdl
            .wren_mdl()
            .get_column_reference(source_column)
        else {
            return plan_err!("Column reference not found for {}", source_column);
        };
        // Calculated fields are handled by the calculation plan, not here.
        if column.is_calculated {
            continue;
        }
        let model = dataset.try_as_model().ok_or_else(|| {
            internal_datafusion_err!("Only support model as source dataset")
        })?;
        let expr = create_wren_expr_for_model(
            &source_column.name,
            model,
            Arc::clone(&session_state_ref),
        )?;
        required_fields
            .entry(relation.clone())
            .or_default()
            .insert(OrdExpr::with_column(expr, Arc::clone(&column)));
    }
    Ok(())
}

/// Collect the required fields for the model plan.
/// It collects the calculated fields for building the calculation plan.
/// It collects the non-calculated source columns for building the model source plan.
fn collect_model_required_fields(
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
session_state_ref: SessionStateRef,
model_required_fields: &mut HashMap<TableReference, BTreeSet<OrdExpr>>,
qualified_column: &Column,
required_fields: &mut HashMap<TableReference, BTreeSet<OrdExpr>>,
) -> Result<()> {
let Some(set) = analyzed_wren_mdl
.lineage()
Expand Down Expand Up @@ -591,7 +646,7 @@ fn collect_model_required_fields(
}
.alias(column.name.clone());
debug!("Required Calculated field: {}", &expr_plan);
model_required_fields
required_fields
.entry(relation_ref.clone())
.or_default()
.insert(OrdExpr::with_column(expr_plan, column));
Expand All @@ -603,7 +658,7 @@ fn collect_model_required_fields(
Arc::clone(&session_state_ref),
)?;
debug!("Required field: {}", &expr_plan);
model_required_fields
required_fields
.entry(relation_ref.clone())
.or_default()
.insert(OrdExpr::with_column(expr_plan, column));
Expand Down
58 changes: 27 additions & 31 deletions wren-core/core/src/logical_plan/analyze/relation_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use crate::mdl::Dataset;
use crate::mdl::{AnalyzedWrenMDL, SessionStateRef};
use crate::{mdl, DataFusionError};
use datafusion::common::alias::AliasGenerator;
use datafusion::common::TableReference;
use datafusion::common::{
internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, Result,
};
use datafusion::common::{plan_datafusion_err, TableReference};
use datafusion::logical_expr::{
col, Expr, Extension, LogicalPlan, LogicalPlanBuilder, SubqueryAlias,
UserDefinedLogicalNodeCore,
Expand Down Expand Up @@ -93,37 +93,33 @@ impl RelationChain {
};
match target {
Dataset::Model(target_model) => {
let node = if fields.iter().any(|e| {
e.column.is_some() && e.column.clone().unwrap().is_calculated
}) {
let schema = create_schema(
fields.iter().filter_map(|e| e.column.clone()).collect(),
)?;
let plan = ModelPlanNode::new(
Arc::clone(target_model),
fields.iter().cloned().map(|c| c.expr).collect(),
None,
Arc::clone(&analyzed_wren_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
)?;
let schema = create_schema(
fields
.iter()
.map(|e| {
e.column.clone().ok_or_else(|| {
plan_datafusion_err!(
"Required field {:?} has no physical column",
e.expr
)
})
})
.collect::<Result<_>>()?,
)?;
let exprs = fields.iter().cloned().map(|c| c.expr).collect();
let plan = ModelPlanNode::new(
Arc::clone(target_model),
exprs,
None,
Arc::clone(&analyzed_wren_mdl),
Arc::clone(&session_state_ref),
Arc::clone(&properties),
)?;

let df_schema =
DFSchemaRef::from(DFSchema::try_from(schema).unwrap());
LogicalPlan::Extension(Extension {
node: Arc::new(PartialModelPlanNode::new(plan, df_schema)),
})
} else {
LogicalPlan::Extension(Extension {
node: Arc::new(ModelSourceNode::new(
Arc::clone(target_model),
fields.iter().cloned().map(|c| c.expr).collect(),
Arc::clone(&analyzed_wren_mdl),
Arc::clone(&session_state_ref),
None,
)?),
})
};
let df_schema = DFSchemaRef::from(DFSchema::try_from(schema)?);
let node = LogicalPlan::Extension(Extension {
node: Arc::new(PartialModelPlanNode::new(plan, df_schema)),
});
relation_chain = RelationChain::Chain(
node,
link.join_type,
Expand Down
Loading