Skip to content

Commit

Permalink
feat: convert in list to subquery
Browse files Browse the repository at this point in the history
  • Loading branch information
xudong963 committed Oct 23, 2023
1 parent 37b8eee commit aa827c0
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 133 deletions.
6 changes: 6 additions & 0 deletions src/query/settings/src/settings_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ impl DefaultSettings {
possible_values: None,
display_in_show_settings: true,
}),
("inlist_to_subquery", DefaultSettingValue {
value: UserSettingValue::UInt64(0),
desc: "Converts IN list to subquery.",
possible_values: None,
display_in_show_settings: true,
}),
("disable_join_reorder", DefaultSettingValue {
value: UserSettingValue::UInt64(0),
desc: "Disable join reorder optimization.",
Expand Down
4 changes: 4 additions & 0 deletions src/query/settings/src/settings_getter_setter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ impl Settings {
Ok(self.try_get_u64("enable_cbo")? != 0)
}

/// Reports whether the `inlist_to_subquery` setting is switched on.
///
/// The setting is read as a `u64`; any non-zero value is treated as enabled.
///
/// # Errors
/// Propagates whatever error `try_get_u64` returns for this key.
pub fn get_inlist_to_subquery(&self) -> Result<bool> {
    let raw = self.try_get_u64("inlist_to_subquery")?;
    Ok(raw != 0)
}

/// Reports whether the `disable_join_reorder` setting is switched on.
///
/// The setting is read as a `u64`; any non-zero value is treated as enabled.
///
/// # Errors
/// Propagates whatever error `try_get_u64` returns for this key.
pub fn get_disable_join_reorder(&self) -> Result<bool> {
    let raw = self.try_get_u64("disable_join_reorder")?;
    Ok(raw != 0)
}
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/src/planner/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,5 @@ pub use scalar::ScalarBinder;
pub use scalar_common::*;
pub use scalar_visitor::*;
pub use table::parse_result_scan_args;
pub use values::bind_values;
pub use window::WindowOrderByInfo;
280 changes: 150 additions & 130 deletions src/query/sql/src/planner/binder/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use common_ast::ast::Expr as AExpr;
use common_catalog::table_context::TableContext;
use common_exception::ErrorCode;
use common_exception::Result;
use common_exception::Span;
Expand All @@ -37,6 +38,8 @@ use crate::plans::ConstantTableScan;
use crate::BindContext;
use crate::Binder;
use crate::ColumnBindingBuilder;
use crate::MetadataRef;
use crate::NameResolutionContext;
use crate::ScalarBinder;
use crate::Visibility;

Expand All @@ -48,149 +51,166 @@ impl Binder {
span: Span,
values: &Vec<Vec<AExpr>>,
) -> Result<(SExpr, BindContext)> {
if values.is_empty() {
return Err(ErrorCode::SemanticError(
"Values lists must have at least one row".to_string(),
)
.set_span(span));
}
let same_length = values.windows(2).all(|v| v[0].len() == v[1].len());
if !same_length {
return Err(ErrorCode::SemanticError(
"Values lists must all be the same length".to_string(),
)
.set_span(span));
}

let num_rows = values.len();
let num_cols = values[0].len();

// assigns default column names col0, col1, etc.
let names = (0..num_cols)
.map(|i| format!("col{}", i))
.collect::<Vec<_>>();

let mut scalar_binder = ScalarBinder::new(
bind_context,
bind_values(
self.ctx.clone(),
&self.name_resolution_ctx,
self.metadata.clone(),
&[],
HashMap::new(),
Box::new(IndexMap::new()),
);

let mut col_scalars = vec![Vec::with_capacity(values.len()); num_cols];
let mut common_types: Vec<Option<DataType>> = vec![None; num_cols];

for row_values in values.iter() {
for (i, value) in row_values.iter().enumerate() {
let (scalar, data_type) = scalar_binder.bind(value).await?;
col_scalars[i].push((scalar, data_type.clone()));

// Get the common data type for each columns.
match &common_types[i] {
Some(common_type) => {
if common_type != &data_type {
let new_common_type = common_super_type(
common_type.clone(),
data_type.clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
);
if new_common_type.is_none() {
return Err(ErrorCode::SemanticError(format!(
"{} and {} don't have common data type",
common_type, data_type
))
.set_span(span));
}
common_types[i] = new_common_type;
bind_context,
span,
values,
)
.await
}
}

pub async fn bind_values(
ctx: Arc<dyn TableContext>,
name_resolution_ctx: &NameResolutionContext,
metadata: MetadataRef,
bind_context: &mut BindContext,
span: Span,
values: &Vec<Vec<AExpr>>,
) -> Result<(SExpr, BindContext)> {
if values.is_empty() {
return Err(ErrorCode::SemanticError(
"Values lists must have at least one row".to_string(),
)
.set_span(span));
}
let same_length = values.windows(2).all(|v| v[0].len() == v[1].len());
if !same_length {
return Err(ErrorCode::SemanticError(
"Values lists must all be the same length".to_string(),
)
.set_span(span));
}

let num_rows = values.len();
let num_cols = values[0].len();

// assigns default column names col0, col1, etc.
let names = (0..num_cols)
.map(|i| format!("col{}", i))
.collect::<Vec<_>>();

let mut scalar_binder = ScalarBinder::new(
bind_context,
ctx.clone(),
name_resolution_ctx,
metadata.clone(),
&[],
HashMap::new(),
Box::new(IndexMap::new()),
);

let mut col_scalars = vec![Vec::with_capacity(values.len()); num_cols];
let mut common_types: Vec<Option<DataType>> = vec![None; num_cols];

for row_values in values.iter() {
for (i, value) in row_values.iter().enumerate() {
let (scalar, data_type) = scalar_binder.bind(value).await?;
col_scalars[i].push((scalar, data_type.clone()));

// Get the common data type for each columns.
match &common_types[i] {
Some(common_type) => {
if common_type != &data_type {
let new_common_type = common_super_type(
common_type.clone(),
data_type.clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
);
if new_common_type.is_none() {
return Err(ErrorCode::SemanticError(format!(
"{} and {} don't have common data type",
common_type, data_type
))
.set_span(span));
}
common_types[i] = new_common_type;
}
None => {
common_types[i] = Some(data_type);
}
}
None => {
common_types[i] = Some(data_type);
}
}
}
}

let mut value_fields = Vec::with_capacity(names.len());
for (name, common_type) in names.into_iter().zip(common_types.into_iter()) {
let value_field = DataField::new(&name, common_type.unwrap());
value_fields.push(value_field);
}
let value_schema = DataSchema::new(value_fields);

let input = DataBlock::empty();
let func_ctx = self.ctx.get_function_context()?;
let evaluator = Evaluator::new(&input, &func_ctx, &BUILTIN_FUNCTIONS);

// use values to build columns
let mut value_columns = Vec::with_capacity(col_scalars.len());
for (scalars, value_field) in col_scalars.iter().zip(value_schema.fields().iter()) {
let mut builder =
ColumnBuilder::with_capacity(value_field.data_type(), col_scalars.len());
for (scalar, value_type) in scalars {
let scalar = if value_type != value_field.data_type() {
wrap_cast_scalar(scalar, value_type, value_field.data_type())?
} else {
scalar.clone()
};
let expr = scalar.as_expr()?.project_column_ref(|col| {
value_schema.index_of(&col.index.to_string()).unwrap()
});
let result = evaluator.run(&expr)?;

match result.as_scalar() {
Some(val) => {
builder.push(val.as_ref());
}
None => {
return Err(ErrorCode::SemanticError(format!(
"Value must be a scalar, but get {}",
result
))
.set_span(span));
}
let mut value_fields = Vec::with_capacity(names.len());
for (name, common_type) in names.into_iter().zip(common_types.into_iter()) {
let value_field = DataField::new(&name, common_type.unwrap());
value_fields.push(value_field);
}
let value_schema = DataSchema::new(value_fields);

let input = DataBlock::empty();
let func_ctx = ctx.get_function_context()?;
let evaluator = Evaluator::new(&input, &func_ctx, &BUILTIN_FUNCTIONS);

// use values to build columns
let mut value_columns = Vec::with_capacity(col_scalars.len());
for (scalars, value_field) in col_scalars.iter().zip(value_schema.fields().iter()) {
let mut builder = ColumnBuilder::with_capacity(value_field.data_type(), col_scalars.len());
for (scalar, value_type) in scalars {
let scalar = if value_type != value_field.data_type() {
wrap_cast_scalar(scalar, value_type, value_field.data_type())?
} else {
scalar.clone()
};
let expr = scalar
.as_expr()?
.project_column_ref(|col| value_schema.index_of(&col.index.to_string()).unwrap());
let result = evaluator.run(&expr)?;

match result.as_scalar() {
Some(val) => {
builder.push(val.as_ref());
}
None => {
return Err(ErrorCode::SemanticError(format!(
"Value must be a scalar, but get {}",
result
))
.set_span(span));
}
}
value_columns.push(builder.build());
}
value_columns.push(builder.build());
}

// add column bindings
let mut columns = ColumnSet::new();
let mut fields = Vec::with_capacity(values.len());
for value_field in value_schema.fields() {
let index = self
.metadata
.write()
.add_derived_column(value_field.name().clone(), value_field.data_type().clone());
columns.insert(index);

let column_binding = ColumnBindingBuilder::new(
value_field.name().clone(),
index,
Box::new(value_field.data_type().clone()),
Visibility::Visible,
)
.build();
bind_context.add_column_binding(column_binding);

let field = DataField::new(&index.to_string(), value_field.data_type().clone());
fields.push(field);
// add column bindings
let mut columns = ColumnSet::new();
let mut fields = Vec::with_capacity(values.len());
for value_field in value_schema.fields() {
let index = metadata
.write()
.add_derived_column(value_field.name().clone(), value_field.data_type().clone());
columns.insert(index);

let column_binding = ColumnBindingBuilder::new(
value_field.name().clone(),
index,
Box::new(value_field.data_type().clone()),
Visibility::Visible,
)
.build();
bind_context.add_column_binding(column_binding);

let field = DataField::new(&index.to_string(), value_field.data_type().clone());
fields.push(field);
}
let schema = DataSchemaRefExt::create(fields);

let s_expr = SExpr::create_leaf(Arc::new(
ConstantTableScan {
values: value_columns,
num_rows,
schema,
columns,
}
let schema = DataSchemaRefExt::create(fields);

let s_expr = SExpr::create_leaf(Arc::new(
ConstantTableScan {
values: value_columns,
num_rows,
schema,
columns,
}
.into(),
));
.into(),
));

Ok((s_expr, bind_context.clone()))
}
Ok((s_expr, bind_context.clone()))
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,10 @@ impl SubqueryRewriter {
Arc::new(self.rewrite(s_expr.child(0)?)?),
)),

RelOperator::DummyTableScan(_) | RelOperator::Scan(_) | RelOperator::CteScan(_) => {
Ok(s_expr.clone())
}
RelOperator::DummyTableScan(_)
| RelOperator::Scan(_)
| RelOperator::CteScan(_)
| RelOperator::ConstantTableScan(_) => Ok(s_expr.clone()),

_ => Err(ErrorCode::Internal("Invalid plan type")),
}
Expand Down
Loading

0 comments on commit aa827c0

Please sign in to comment.