Skip to content

chore: support collecting statistics of multiple join expressions in merge into #13511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 167 additions & 118 deletions src/query/service/src/interpreters/interpreter_merge_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use common_catalog::table::TableExt;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::ConstantFolder;
use common_expression::DataBlock;
use common_expression::DataSchema;
use common_expression::DataSchemaRef;
use common_expression::FieldIndex;
Expand All @@ -48,6 +47,7 @@ use common_sql::plans::BoundColumnRef;
use common_sql::plans::ConstantExpr;
use common_sql::plans::EvalScalar;
use common_sql::plans::FunctionCall;
use common_sql::plans::JoinType;
use common_sql::plans::MergeInto as MergePlan;
use common_sql::plans::Plan;
use common_sql::plans::RelOperator;
Expand Down Expand Up @@ -86,6 +86,33 @@ pub struct MergeIntoInterpreter {
plan: MergePlan,
}

/// Borrowed view of a merge-into join `SExpr`, split into the pieces the
/// interpreter needs: the equi-join conditions of each side and the child
/// plans of each side.
///
/// In a merge-into plan the join is a `JoinType::Right` whose left child is
/// the target table and whose right child is the source; `MergeStyleJoin::new`
/// asserts this shape when deconstructing the join.
struct MergeStyleJoin<'a> {
    // Join-key expressions evaluated on the source side
    // (the `right_conditions` of the right join).
    source_conditions: &'a [ScalarExpr],
    // Join-key expressions evaluated on the target side
    // (the `left_conditions` of the right join).
    target_conditions: &'a [ScalarExpr],
    // Plan producing source rows (right child of the join).
    source_sexpr: &'a SExpr,
    // Plan producing target rows (left child of the join).
    target_sexpr: &'a SExpr,
}

impl MergeStyleJoin<'_> {
    /// Deconstruct a merge-into join `SExpr` into its merge-style parts.
    ///
    /// The merge-into planner always produces a `JoinType::Right` join whose
    /// left child is the target table plan and whose right child is the
    /// source plan; the right/left join conditions are the source/target
    /// join-key expressions respectively.
    ///
    /// # Panics
    /// Panics if `join` is not a `Join` operator, if its join type is not
    /// `Right`, or if either child is missing.
    pub fn new(join: &SExpr) -> MergeStyleJoin {
        // The caller guarantees a join node here; anything else is a planner bug.
        let join_op = if let RelOperator::Join(op) = join.plan() {
            op
        } else {
            unreachable!()
        };
        assert!(matches!(join_op.join_type, JoinType::Right));
        MergeStyleJoin {
            // Right side of the right join is the merge source.
            source_conditions: &join_op.right_conditions,
            target_conditions: &join_op.left_conditions,
            source_sexpr: join.child(1).unwrap(),
            target_sexpr: join.child(0).unwrap(),
        }
    }
}

impl MergeIntoInterpreter {
pub fn try_create(ctx: Arc<QueryContext>, plan: MergePlan) -> Result<InterpreterPtr> {
Ok(Arc::new(MergeIntoInterpreter { ctx, plan }))
Expand Down Expand Up @@ -448,40 +475,90 @@ impl MergeIntoInterpreter {
// EvalScalar(source_join_side_expr)
// \
// SourcePlan

let source_plan = join.child(1)?;
let join_op = match join.plan() {
RelOperator::Join(j) => j,
_ => unreachable!(),
};
if join_op.left_conditions.len() != 1 || join_op.right_conditions.len() != 1 {
let m_join = MergeStyleJoin::new(join);
let mut eval_scalar_items = Vec::with_capacity(m_join.source_conditions.len());
let mut min_max_binding = Vec::with_capacity(m_join.source_conditions.len() * 2);
let mut min_max_scalar_items = Vec::with_capacity(m_join.source_conditions.len() * 2);
if m_join.source_conditions.is_empty() {
return Ok(Box::new(join.clone()));
}
let source_side_expr = &join_op.right_conditions[0];
let target_side_expr = &join_op.left_conditions[0];
for source_side_expr in m_join.source_conditions {
// eval source side join expr
let index = metadata
.write()
.add_derived_column("".to_string(), source_side_expr.data_type()?);
let evaled = ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: ColumnBindingBuilder::new(
"".to_string(),
index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build(),
});
eval_scalar_items.push(ScalarItem {
scalar: source_side_expr.clone(),
index,
});

// eval min/max of source side join expr
let min_display_name = format!("min({:?})", source_side_expr);
let max_display_name = format!("max({:?})", source_side_expr);
let min_index = metadata
.write()
.add_derived_column(min_display_name.clone(), source_side_expr.data_type()?);
let max_index = metadata
.write()
.add_derived_column(max_display_name.clone(), source_side_expr.data_type()?);
let min_binding = ColumnBindingBuilder::new(
min_display_name.clone(),
min_index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build();
let max_binding = ColumnBindingBuilder::new(
max_display_name.clone(),
max_index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build();
min_max_binding.push(min_binding);
min_max_binding.push(max_binding);
let min = ScalarItem {
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
func_name: "min".to_string(),
distinct: false,
params: vec![],
args: vec![evaled.clone()],
return_type: Box::new(source_side_expr.data_type()?),
display_name: min_display_name.clone(),
}),
index: min_index,
};
let max = ScalarItem {
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
func_name: "max".to_string(),
distinct: false,
params: vec![],
args: vec![evaled],
return_type: Box::new(source_side_expr.data_type()?),
display_name: max_display_name.clone(),
}),
index: max_index,
};
min_max_scalar_items.push(min);
min_max_scalar_items.push(max);
}

let group_item = eval_scalar_items[0].clone();

// eval source side join expr
let source_side_join_expr_index = metadata.write().add_derived_column(
"source_side_join_expr".to_string(),
source_side_expr.data_type()?,
);
let source_side_join_expr_binding = ColumnBindingBuilder::new(
"source_side_join_expr".to_string(),
source_side_join_expr_index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build();
let evaled_source_side_join_expr = ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: source_side_join_expr_binding.clone(),
});
let eval_source_side_join_expr_op = EvalScalar {
items: vec![ScalarItem {
scalar: source_side_expr.clone(),
index: source_side_join_expr_index,
}],
items: eval_scalar_items,
};
let source_plan = m_join.source_sexpr;
let eval_target_side_condition_sexpr = if let RelOperator::Exchange(_) = source_plan.plan()
{
// there is another row_number operator here
Expand All @@ -496,57 +573,13 @@ impl MergeIntoInterpreter {
)
};

// eval min/max of source side join expr
let min_display_name = format!("min({:?})", source_side_expr);
let max_display_name = format!("max({:?})", source_side_expr);
let min_index = metadata
.write()
.add_derived_column(min_display_name.clone(), source_side_expr.data_type()?);
let max_index = metadata
.write()
.add_derived_column(max_display_name.clone(), source_side_expr.data_type()?);
let mut bind_context = Box::new(BindContext::new());
let min_binding = ColumnBindingBuilder::new(
min_display_name.clone(),
min_index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build();
let max_binding = ColumnBindingBuilder::new(
max_display_name.clone(),
max_index,
Box::new(source_side_expr.data_type()?),
Visibility::Visible,
)
.build();
bind_context.columns = vec![min_binding.clone(), max_binding.clone()];
let min = ScalarItem {
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
func_name: "min".to_string(),
distinct: false,
params: vec![],
args: vec![evaled_source_side_join_expr.clone()],
return_type: Box::new(source_side_expr.data_type()?),
display_name: min_display_name.clone(),
}),
index: min_index,
};
let max = ScalarItem {
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
func_name: "max".to_string(),
distinct: false,
params: vec![],
args: vec![evaled_source_side_join_expr],
return_type: Box::new(source_side_expr.data_type()?),
display_name: max_display_name.clone(),
}),
index: max_index,
};
bind_context.columns = min_max_binding;

let agg_partial_op = Aggregate {
mode: AggregateMode::Partial,
group_items: vec![],
aggregate_functions: vec![min.clone(), max.clone()],
group_items: vec![group_item.clone()],
aggregate_functions: min_max_scalar_items.clone(),
from_distinct: false,
limit: None,
grouping_sets: None,
Expand All @@ -557,8 +590,8 @@ impl MergeIntoInterpreter {
);
let agg_final_op = Aggregate {
mode: AggregateMode::Final,
group_items: vec![],
aggregate_functions: vec![min.clone(), max.clone()],
group_items: vec![group_item],
aggregate_functions: min_max_scalar_items,
from_distinct: false,
limit: None,
grouping_sets: None,
Expand All @@ -577,45 +610,61 @@ impl MergeIntoInterpreter {
let stream: SendableDataBlockStream = interpreter.execute(ctx.clone()).await?;
let blocks = stream.collect::<Result<Vec<_>>>().await?;

debug_assert_eq!(blocks.len(), 1);
let block = &blocks[0];
debug_assert_eq!(block.num_columns(), 2);

let get_scalar_expr = |block: &DataBlock, index: usize| {
let block_entry = &block.columns()[index];
let scalar = match &block_entry.value {
common_expression::Value::Scalar(scalar) => scalar.clone(),
common_expression::Value::Column(column) => {
debug_assert_eq!(column.len(), 1);
let value_ref = column.index(0).unwrap();
value_ref.to_owned()
}
};
ScalarExpr::ConstantExpr(ConstantExpr {
span: None,
value: scalar,
})
};

let min_scalar = get_scalar_expr(block, 0);
let max_scalar = get_scalar_expr(block, 1);

// 2. build filter and push down to target side
let gte_min = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "gte".to_string(),
params: vec![],
arguments: vec![target_side_expr.clone(), min_scalar],
});
let lte_max = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "lte".to_string(),
params: vec![],
arguments: vec![target_side_expr.clone(), max_scalar],
});

let filters = vec![gte_min, lte_max];
let mut target_plan = join.child(0)?.clone();
let mut filters = Vec::with_capacity(m_join.target_conditions.len());

for (i, target_side_expr) in m_join.target_conditions.iter().enumerate() {
let mut filter_parts = vec![];
for block in blocks.iter() {
let block = block.convert_to_full();
let min_column = block.get_by_offset(i * 2).value.as_column().unwrap();
let max_column = block.get_by_offset(i * 2 + 1).value.as_column().unwrap();
for (min_scalar, max_scalar) in min_column.iter().zip(max_column.iter()) {
let gte_min = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "gte".to_string(),
params: vec![],
arguments: vec![
target_side_expr.clone(),
ScalarExpr::ConstantExpr(ConstantExpr {
span: None,
value: min_scalar.to_owned(),
}),
],
});
let lte_max = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "lte".to_string(),
params: vec![],
arguments: vec![
target_side_expr.clone(),
ScalarExpr::ConstantExpr(ConstantExpr {
span: None,
value: max_scalar.to_owned(),
}),
],
});
let and = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "and".to_string(),
params: vec![],
arguments: vec![gte_min, lte_max],
});
filter_parts.push(and);
}
}
let mut or = filter_parts[0].clone();
for filter_part in filter_parts.iter().skip(1) {
or = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "or".to_string(),
params: vec![],
arguments: vec![or, filter_part.clone()],
});
}
filters.push(or);
}
let mut target_plan = m_join.target_sexpr.clone();
Self::push_down_filters(&mut target_plan, &filters)?;
let new_sexpr =
join.replace_children(vec![Arc::new(target_plan), Arc::new(source_plan.clone())]);
Expand Down