From 56f0f178db824e145b76c93d1a3f3f13cee8f9d8 Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Fri, 24 May 2024 07:28:48 +0800
Subject: [PATCH] fix(query): unify optimize plan call in one place (#15630)

* fix(query): fix missing exchange in create table as select
---
 .../interpreter_copy_into_table.rs            |  8 ++-
 .../src/interpreters/interpreter_insert.rs    |  1 +
 src/query/sql/src/planner/binder/ddl/table.rs |  7 +--
 src/query/sql/src/planner/binder/insert.rs    |  8 +--
 .../src/planner/binder/insert_multi_table.rs  |  8 +--
 src/query/sql/src/planner/binder/replace.rs   |  7 +--
 .../sql/src/planner/optimizer/optimizer.rs    | 58 +++++++++++++++++++
 src/query/sql/src/planner/plans/plan.rs       | 27 +++++++++
 .../suites/mode/cluster/create_table.test     | 44 ++++++++++++++
 .../cluster/distributed_copy_into_table.test  | 38 ++++--------
 10 files changed, 153 insertions(+), 53 deletions(-)
 create mode 100644 tests/sqllogictests/suites/mode/cluster/create_table.test

diff --git a/src/query/service/src/interpreters/interpreter_copy_into_table.rs b/src/query/service/src/interpreters/interpreter_copy_into_table.rs
index 68474d6e9181b..33ce659c91b87 100644
--- a/src/query/service/src/interpreters/interpreter_copy_into_table.rs
+++ b/src/query/service/src/interpreters/interpreter_copy_into_table.rs
@@ -106,7 +106,13 @@ impl CopyIntoTableInterpreter {
             .await?;
         let mut update_stream_meta_reqs = vec![];
         let (source, project_columns) = if let Some(ref query) = plan.query {
-            let (query_interpreter, update_stream_meta) = self.build_query(query).await?;
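+            // With distributed COPY enabled, strip the top-level Exchange(Merge)
+            // from the optimized query so the pipeline stays distributed.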
+            let query = if plan.enable_distributed {
+                query.remove_exchange_for_select()
+            } else {
+                *query.clone()
+            };
+
+            let (query_interpreter, update_stream_meta) = self.build_query(&query).await?;
             update_stream_meta_reqs = update_stream_meta;
             let query_physical_plan = Box::new(query_interpreter.build_physical_plan().await?);
 
diff --git a/src/query/service/src/interpreters/interpreter_insert.rs b/src/query/service/src/interpreters/interpreter_insert.rs
index d017e340aecb2..c62cea6ac8c48 100644
--- a/src/query/service/src/interpreters/interpreter_insert.rs
+++ b/src/query/service/src/interpreters/interpreter_insert.rs
@@ -216,6 +216,7 @@ impl Interpreter for InsertInterpreter {
                 let catalog = self.ctx.get_catalog(&self.plan.catalog).await?;
                 let catalog_info = catalog.info();
 
+                // Remove the last Exchange(Merge) plan so the insert can be distributed across nodes.
                 let insert_select_plan = match select_plan {
                     PhysicalPlan::Exchange(ref mut exchange) => {
                         // insert can be dispatched to different nodes
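
This hunk and the COPY INTO change above apply the same trick at the physical-plan
level: if the optimized plan ends in an Exchange that merges partial streams back
onto a single coordinator, that final merge is peeled off so the insert sink can be
built on every node. A toy sketch of the pattern, using hypothetical simplified
types rather than Databend's real `PhysicalPlan`:

```rust
// Toy sketch, not Databend's real `PhysicalPlan`: drop a top-level
// Exchange so its input keeps running on every node instead of being
// merged back onto a single coordinator first.
#[derive(Debug, PartialEq)]
enum PhysicalPlan {
    Exchange { input: Box<PhysicalPlan> },
    TableScan,
}

fn into_distributed_input(plan: PhysicalPlan) -> PhysicalPlan {
    match plan {
        // Keep the Exchange's input; the merge step is dropped.
        PhysicalPlan::Exchange { input } => *input,
        other => other,
    }
}

fn main() {
    let plan = PhysicalPlan::Exchange {
        input: Box::new(PhysicalPlan::TableScan),
    };
    assert_eq!(into_distributed_input(plan), PhysicalPlan::TableScan);
}
```
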
diff --git a/src/query/sql/src/planner/binder/ddl/table.rs b/src/query/sql/src/planner/binder/ddl/table.rs
index 799c2f6290abb..26114acaa7f2f 100644
--- a/src/query/sql/src/planner/binder/ddl/table.rs
+++ b/src/query/sql/src/planner/binder/ddl/table.rs
@@ -92,8 +92,6 @@ use crate::binder::scalar::ScalarBinder;
 use crate::binder::Binder;
 use crate::binder::ColumnBindingBuilder;
 use crate::binder::Visibility;
-use crate::optimizer::optimize;
-use crate::optimizer::OptimizerContext;
 use crate::parse_computed_expr_to_string;
 use crate::parse_default_expr_to_string;
 use crate::planner::semantic::normalize_identifier;
@@ -661,10 +659,7 @@ impl Binder {
                 let mut bind_context = BindContext::new();
                 let stmt = Statement::Query(Box::new(*query.clone()));
                 let select_plan = self.bind_statement(&mut bind_context, &stmt).await?;
-                // Don't enable distributed optimization for `CREATE TABLE ... AS SELECT ...` for now
-                let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone());
-                let optimized_plan = optimize(opt_ctx, select_plan).await?;
-                Some(Box::new(optimized_plan))
+                Some(Box::new(select_plan))
             } else {
                 None
             },
diff --git a/src/query/sql/src/planner/binder/insert.rs b/src/query/sql/src/planner/binder/insert.rs
index f8e2ac03257bf..e29ab7e09a243 100644
--- a/src/query/sql/src/planner/binder/insert.rs
+++ b/src/query/sql/src/planner/binder/insert.rs
@@ -29,8 +29,6 @@ use databend_common_meta_app::principal::OnErrorMode;
 
 use crate::binder::Binder;
 use crate::normalize_identifier;
-use crate::optimizer::optimize;
-use crate::optimizer::OptimizerContext;
 use crate::plans::insert::InsertValue;
 use crate::plans::CopyIntoTableMode;
 use crate::plans::Insert;
@@ -179,11 +177,7 @@ impl Binder {
             InsertSource::Select { query } => {
                 let statement = Statement::Query(query);
                 let select_plan = self.bind_statement(bind_context, &statement).await?;
-                let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone())
-                    .with_enable_distributed_optimization(!self.ctx.get_cluster().is_empty());
-
-                let optimized_plan = optimize(opt_ctx, select_plan).await?;
-                Ok(InsertInputSource::SelectPlan(Box::new(optimized_plan)))
+                Ok(InsertInputSource::SelectPlan(Box::new(select_plan)))
             }
         };
 
diff --git a/src/query/sql/src/planner/binder/insert_multi_table.rs b/src/query/sql/src/planner/binder/insert_multi_table.rs
index 617e4c5fc0625..bf51fcae548ac 100644
--- a/src/query/sql/src/planner/binder/insert_multi_table.rs
+++ b/src/query/sql/src/planner/binder/insert_multi_table.rs
@@ -26,8 +26,6 @@ use databend_common_expression::DataSchemaRef;
 use databend_common_expression::TableSchema;
 
 use crate::binder::ScalarBinder;
-use crate::optimizer::optimize;
-use crate::optimizer::OptimizerContext;
 use crate::plans::Else;
 use crate::plans::InsertMultiTable;
 use crate::plans::Into;
@@ -62,9 +60,6 @@ impl Binder {
             };
 
             let (s_expr, bind_context) = self.bind_single_table(bind_context, &table_ref).await?;
-            let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone())
-                .with_enable_distributed_optimization(!self.ctx.get_cluster().is_empty());
-
             let select_plan = Plan::Query {
                 s_expr: Box::new(s_expr),
                 metadata: self.metadata.clone(),
@@ -74,8 +69,7 @@ impl Binder {
                 ignore_result: false,
             };
 
-            let optimized_plan = optimize(opt_ctx, select_plan).await?;
-            (optimized_plan, bind_context)
+            (select_plan, bind_context)
         };
 
         let source_schema = input_source.schema();
diff --git a/src/query/sql/src/planner/binder/replace.rs b/src/query/sql/src/planner/binder/replace.rs
index 670672e7e5041..ebc547c023729 100644
--- a/src/query/sql/src/planner/binder/replace.rs
+++ b/src/query/sql/src/planner/binder/replace.rs
@@ -25,8 +25,6 @@ use databend_common_meta_app::principal::OnErrorMode;
 
 use crate::binder::Binder;
 use crate::normalize_identifier;
-use crate::optimizer::optimize;
-use crate::optimizer::OptimizerContext;
 use crate::plans::insert::InsertValue;
 use crate::plans::CopyIntoTableMode;
 use crate::plans::InsertInputSource;
@@ -161,10 +159,7 @@ impl Binder {
             InsertSource::Select { query } => {
                 let statement = Statement::Query(query);
                 let select_plan = self.bind_statement(bind_context, &statement).await?;
-                let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone())
-                    .with_enable_distributed_optimization(false);
-                let optimized_plan = optimize(opt_ctx, select_plan).await?;
-                Ok(InsertInputSource::SelectPlan(Box::new(optimized_plan)))
+                Ok(InsertInputSource::SelectPlan(Box::new(select_plan)))
             }
         };
 
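
The three binder hunks above (`insert.rs`, `insert_multi_table.rs`, `replace.rs`),
together with the `CREATE TABLE ... AS SELECT` change in `ddl/table.rs`, all make
the same move: the binder now returns the raw bound plan and no longer calls the
optimizer itself. The single `optimize` entry point in the next diff descends into
nested SELECT sources instead. A toy sketch of that shape, with hypothetical
simplified types:

```rust
// Toy sketch with hypothetical simplified types: one recursive entry
// point optimizes leaf queries and descends into container plans, so
// binders never invoke the optimizer themselves.
#[derive(Debug)]
enum Plan {
    Query(String),
    Insert { source: Box<Plan> },
}

fn optimize(plan: Plan) -> Plan {
    match plan {
        // Leaf: optimize the query directly.
        Plan::Query(q) => Plan::Query(format!("optimized({q})")),
        // Container: recurse into the nested SELECT source.
        Plan::Insert { source } => Plan::Insert {
            source: Box::new(optimize(*source)),
        },
    }
}

fn main() {
    let plan = Plan::Insert {
        source: Box::new(Plan::Query("select 1".into())),
    };
    // Prints: Insert { source: Query("optimized(select 1)") }
    println!("{:?}", optimize(plan));
}
```
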
diff --git a/src/query/sql/src/planner/optimizer/optimizer.rs b/src/query/sql/src/planner/optimizer/optimizer.rs
index 0d12cea0b856c..2506f353b1d79 100644
--- a/src/query/sql/src/planner/optimizer/optimizer.rs
+++ b/src/query/sql/src/planner/optimizer/optimizer.rs
@@ -15,6 +15,7 @@
 use std::collections::HashSet;
 use std::sync::Arc;
 
+use async_recursion::async_recursion;
 use databend_common_ast::ast::ExplainKind;
 use databend_common_catalog::merge_into_join::MergeIntoJoin;
 use databend_common_catalog::merge_into_join::MergeIntoJoinType;
@@ -50,6 +51,7 @@ use crate::plans::Join;
 use crate::plans::MergeInto;
 use crate::plans::Plan;
 use crate::plans::RelOperator;
+use crate::InsertInputSource;
 use crate::MetadataRef;
 
 #[derive(Clone, Educe)]
@@ -156,6 +158,7 @@ impl<'a> RecursiveOptimizer<'a> {
 }
 
 #[minitrace::trace]
+#[async_recursion]
 pub async fn optimize(opt_ctx: OptimizerContext, plan: Plan) -> Result<Plan> {
     match plan {
         Plan::Query {
@@ -224,10 +227,65 @@ pub async fn optimize(opt_ctx: OptimizerContext, plan: Plan) -> Result<Plan> {
                 "after optimization enable_distributed_copy? : {}",
                 plan.enable_distributed
             );
+
+            if let Some(p) = &plan.query {
+                let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                plan.query = Some(Box::new(optimized_plan));
+            }
             Ok(Plan::CopyIntoTable(plan))
         }
         Plan::MergeInto(plan) => optimize_merge_into(opt_ctx.clone(), plan).await,
 
+        // Distributed INSERT is optimized later, in `physical_plan_builder`;
+        // here we only optimize the nested SELECT source.
+        Plan::Insert(mut plan) => {
+            match plan.source {
+                InsertInputSource::SelectPlan(p) => {
+                    let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                    plan.source = InsertInputSource::SelectPlan(Box::new(optimized_plan));
+                }
+                InsertInputSource::Stage(p) => {
+                    let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                    plan.source = InsertInputSource::Stage(Box::new(optimized_plan));
+                }
+                _ => {}
+            }
+            Ok(Plan::Insert(plan))
+        }
+        Plan::InsertMultiTable(mut plan) => {
+            plan.input_source = optimize(opt_ctx.clone(), plan.input_source.clone()).await?;
+            Ok(Plan::InsertMultiTable(plan))
+        }
+        Plan::Replace(mut plan) => {
+            match plan.source {
+                InsertInputSource::SelectPlan(p) => {
+                    let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                    plan.source = InsertInputSource::SelectPlan(Box::new(optimized_plan));
+                }
+                InsertInputSource::Stage(p) => {
+                    let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                    plan.source = InsertInputSource::Stage(Box::new(optimized_plan));
+                }
+                _ => {}
+            }
+            Ok(Plan::Replace(plan))
+        }
+
+        Plan::CreateTable(mut plan) => {
+            if let Some(p) = &plan.as_select {
+                let optimized_plan = optimize(opt_ctx.clone(), *p.clone()).await?;
+                plan.as_select = Some(Box::new(optimized_plan));
+            }
+
+            Ok(Plan::CreateTable(plan))
+        }
+        // `Plan::RefreshIndex` is skipped here: its query plan is already
+        // optimized in the binder.
         // Pass through statements.
         _ => Ok(plan),
     }
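
One detail worth calling out: `optimize` now awaits itself on the plans nested
inside `Insert`, `Replace`, `InsertMultiTable`, `CreateTable`, and
`CopyIntoTable`, and a plain `async fn` cannot recurse because its generated
future type would have to contain itself. The new `#[async_recursion]` attribute
boxes the recursive call to break that cycle. A minimal standalone sketch
(`depth` is a hypothetical stand-in for `optimize`; the `async-recursion` and
`tokio` crates are assumed as dependencies):

```rust
use async_recursion::async_recursion;

// Without `#[async_recursion]` this does not compile: the future returned
// by `depth` would have to contain itself. The macro rewrites the function
// to return a boxed future, breaking the cycle.
#[async_recursion]
async fn depth(n: u64) -> u64 {
    if n == 0 { 0 } else { 1 + depth(n - 1).await }
}

#[tokio::main]
async fn main() {
    assert_eq!(depth(5).await, 5);
}
```
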
diff --git a/src/query/sql/src/planner/plans/plan.rs b/src/query/sql/src/planner/plans/plan.rs
index 90d984fefed88..13efa48f386df 100644
--- a/src/query/sql/src/planner/plans/plan.rs
+++ b/src/query/sql/src/planner/plans/plan.rs
@@ -24,6 +24,8 @@ use databend_common_expression::DataSchema;
 use databend_common_expression::DataSchemaRef;
 use databend_common_expression::DataSchemaRefExt;
 
+use super::Exchange;
+use super::RelOperator;
 use crate::binder::ExplainConfig;
 use crate::optimizer::SExpr;
 use crate::plans::copy_into_location::CopyIntoLocationPlan;
@@ -488,4 +490,29 @@ impl Plan {
     pub fn has_result_set(&self) -> bool {
         !self.schema().fields().is_empty()
     }
+
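+    /// If this is a `Query` plan whose root operator is `Exchange::Merge`,
+    /// return the same query with that top-level merge removed; otherwise
+    /// return an unchanged clone.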
+    pub fn remove_exchange_for_select(&self) -> Self {
+        if let Plan::Query {
+            s_expr,
+            metadata,
+            bind_context,
+            rewrite_kind,
+            formatted_ast,
+            ignore_result,
+        } = self
+        {
+            if let RelOperator::Exchange(Exchange::Merge) = s_expr.plan.as_ref() {
+                let s_expr = Box::new(s_expr.child(0).unwrap().clone());
+                return Plan::Query {
+                    s_expr,
+                    metadata: metadata.clone(),
+                    bind_context: bind_context.clone(),
+                    rewrite_kind: rewrite_kind.clone(),
+                    formatted_ast: formatted_ast.clone(),
+                    ignore_result: *ignore_result,
+                };
+            }
+        }
+        self.clone()
+    }
 }
diff --git a/tests/sqllogictests/suites/mode/cluster/create_table.test b/tests/sqllogictests/suites/mode/cluster/create_table.test
new file mode 100644
index 0000000000000..ace5b3c2b23d6
--- /dev/null
+++ b/tests/sqllogictests/suites/mode/cluster/create_table.test
@@ -0,0 +1,44 @@
+query T
+explain create or replace table t2 as select number % 400 d, max(number) from numbers(10000000) group by number limit 3;
+----
+CreateTableAsSelect:
+(empty)
+EvalScalar
+├── output columns: [max(number) (#6), d (#7)]
+├── expressions: [numbers.number (#4) % 400]
+├── estimated rows: 3.00
+└── Limit
+    ├── output columns: [max(number) (#6), numbers.number (#4)]
+    ├── limit: 3
+    ├── offset: 0
+    ├── estimated rows: 3.00
+    └── Exchange
+        ├── output columns: [max(number) (#6), numbers.number (#4)]
+        ├── exchange type: Merge
+        └── Limit
+            ├── output columns: [max(number) (#6), numbers.number (#4)]
+            ├── limit: 3
+            ├── offset: 0
+            ├── estimated rows: 3.00
+            └── AggregateFinal
+                ├── output columns: [max(number) (#6), numbers.number (#4)]
+                ├── group by: [number]
+                ├── aggregate functions: [max(number)]
+                ├── limit: 3
+                ├── estimated rows: 10000000.00
+                └── Exchange
+                    ├── output columns: [max(number) (#6), numbers.number (#4)]
+                    ├── exchange type: Hash(0)
+                    └── AggregatePartial
+                        ├── group by: [number]
+                        ├── aggregate functions: [max(number)]
+                        ├── estimated rows: 10000000.00
+                        └── TableScan
+                            ├── table: default.system.numbers
+                            ├── output columns: [number (#4)]
+                            ├── read rows: 10000000
+                            ├── read size: 76.29 MiB
+                            ├── partitions total: 153
+                            ├── partitions scanned: 153
+                            ├── push downs: [filters: [], limit: NONE]
+                            └── estimated rows: 10000000.00
diff --git a/tests/sqllogictests/suites/mode/cluster/distributed_copy_into_table.test b/tests/sqllogictests/suites/mode/cluster/distributed_copy_into_table.test
index 94b9a194d4d2e..80172fe09feb1 100644
--- a/tests/sqllogictests/suites/mode/cluster/distributed_copy_into_table.test
+++ b/tests/sqllogictests/suites/mode/cluster/distributed_copy_into_table.test
@@ -1,26 +1,12 @@
 statement ok
 set enable_distributed_copy_into = 1;
 
-statement ok
-drop table if exists test_order;
-
-statement ok
-drop table if exists random_source;
-
-statement ok
-drop stage if exists test_stage;
-
-statement ok
-drop table if exists parquet_table;
-
-statement ok
-drop stage if exists parquet_stage;
 
 statement ok
-create stage st FILE_FORMAT = (TYPE = CSV);
+create or replace stage st FILE_FORMAT = (TYPE = CSV);
 
 statement ok
-create table table_random(a int not null,b string not null,c string not null) ENGINE = Random;
+create or replace table table_random(a int not null,b string not null,c string not null) ENGINE = Random;
 
 statement ok
 copy into @st from (select a,b,c from table_random limit 1000000);
@@ -47,7 +33,7 @@ statement ok
 copy into @st from (select a,b,c from table_random limit 1000000);
 
 statement ok
-create table t(a int not null,b string not null,c string not null);
+create or replace table t(a int not null,b string not null,c string not null);
 
 statement ok
 copy into t from @st force = true;
@@ -74,10 +60,10 @@ statement ok
 set enable_distributed_copy_into = 1;
 
 statement ok
-create table t_query(a int not null,b string not null,c string not null);
+create or replace table t_query(a int not null,b string not null,c string not null);
 
 statement ok
-create stage st_query FILE_FORMAT = (TYPE = TSV);
+create or replace stage st_query FILE_FORMAT = (TYPE = TSV);
 
 statement ok
 copy into @st_query from (select a,b,c from table_random limit 1000000);
@@ -100,10 +86,10 @@ select count(*) from t_query;
 
 ## add parquet_file_test
 statement ok
-create table parquet_table(a int not null,b string not null,c string not null);
+create or replace table parquet_table(a int not null,b string not null,c string not null);
 
 statement ok
-create stage parquet_stage file_format = (type = parquet);
+create or replace stage parquet_stage file_format = (type = parquet);
 
 statement ok
 copy into @parquet_stage from (select a,b,c from table_random limit 100000);
@@ -148,10 +134,10 @@ select count(*) from parquet_table;
 # make sure it's distributed.
 
 statement ok
-create table t_query2(a int not null,b string not null,c string not null);
+create or replace table t_query2(a int not null,b string not null,c string not null);
 
 statement ok
-create stage st_query2 FILE_FORMAT = (TYPE = TSV);
+create or replace stage st_query2 FILE_FORMAT = (TYPE = TSV);
 
 statement ok
 copy into @st_query2 from (select a,b,c from table_random limit 10);
@@ -178,13 +164,13 @@ select block_count from fuse_snapshot('default','t_query2') limit 1;
 
 #test cluster key
 statement ok
-create table test_order(a int not null,b string not null,c timestamp not null) cluster by(to_yyyymmdd(c),a);
+create or replace table test_order(a int not null,b string not null,c timestamp not null) cluster by(to_yyyymmdd(c),a);
 
 statement ok
-create table random_source like test_order Engine = Random;
+create or replace table random_source like test_order Engine = Random;
 
 statement ok
-create stage test_stage;
+create or replace stage test_stage;
 
 statement ok
 copy into @test_stage from (select * from random_source limit 4000000) FILE_FORMAT=(type=parquet);