From b1900cfd9b84b72de7fff54ef31619a4e5639c32 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 26 Mar 2022 18:35:34 -0400 Subject: [PATCH 01/61] Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral --- dask_planner/src/expression.rs | 152 +++++++++++++++++++++++++- dask_sql/context.py | 2 +- dask_sql/physical/rex/core/call.py | 1 + dask_sql/physical/rex/core/literal.py | 1 + 4 files changed, 151 insertions(+), 5 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index da10e7230..85b0b29ef 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -121,8 +121,13 @@ impl PyExpr { Expr::Alias(..) => "Alias", Expr::Column(..) => "Column", Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), +<<<<<<< HEAD Expr::Literal(..) => "Literal", Expr::BinaryExpr { .. } => "BinaryExpr", +======= + Expr::Literal(..) => String::from("Literal"), + Expr::BinaryExpr {..} => String::from("BinaryExpr"), +>>>>>>> 9038b85... Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral Expr::Not(..) => panic!("Not!!!"), Expr::IsNotNull(..) => panic!("IsNotNull!!!"), Expr::Negative(..) => panic!("Negative!!!"), @@ -210,14 +215,80 @@ impl PyExpr { } /// Gets the operands for a BinaryExpr call - #[pyo3(name = "getOperands")] - pub fn get_operands(&self) -> PyResult> { + pub fn getOperands(&self) -> PyResult> { + println!("PyExpr in getOperands(): {:?}", self.expr); match &self.expr { - Expr::BinaryExpr { left, op: _, right } => { + Expr::BinaryExpr {left, op, right} => { + + let left_type: String = match *left.clone() { + Expr::Alias(..) => String::from("Alias"), + Expr::Column(..) => String::from("Column"), + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(..) => panic!("Literal!!!"), + Expr::BinaryExpr {..} => String::from("BinaryExpr"), + Expr::Not(..) => panic!("Not!!!"), + Expr::IsNotNull(..) => panic!("IsNotNull!!!"), + Expr::Negative(..) => panic!("Negative!!!"), + Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), + Expr::IsNull(..) => panic!("IsNull!!!"), + Expr::Between{..} => panic!("Between!!!"), + Expr::Case{..} => panic!("Case!!!"), + Expr::Cast{..} => panic!("Cast!!!"), + Expr::TryCast{..} => panic!("TryCast!!!"), + Expr::Sort{..} => panic!("Sort!!!"), + Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), + Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), + Expr::WindowFunction{..} => panic!("WindowFunction!!!"), + Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), + Expr::InList{..} => panic!("InList!!!"), + Expr::Wildcard => panic!("Wildcard!!!"), + _ => String::from("OTHER") + }; + + println!("Left Expression Name: {:?}", left_type); + + let right_type: String = match *right.clone() { + Expr::Alias(..) => String::from("Alias"), + Expr::Column(..) => String::from("Column"), + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(scalarValue) => { + let value = match scalarValue { + datafusion::scalar::ScalarValue::Int64(value) => { + println!("value: {:?}", value.unwrap()); + String::from("I64 Value") + }, + _ => { + String::from("CatchAll") + } + }; + String::from("Literal") + }, + Expr::BinaryExpr {..} => String::from("BinaryExpr"), + Expr::Not(..) => panic!("Not!!!"), + Expr::IsNotNull(..) => panic!("IsNotNull!!!"), + Expr::Negative(..) => panic!("Negative!!!"), + Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), + Expr::IsNull(..) 
=> panic!("IsNull!!!"), + Expr::Between{..} => panic!("Between!!!"), + Expr::Case{..} => panic!("Case!!!"), + Expr::Cast{..} => panic!("Cast!!!"), + Expr::TryCast{..} => panic!("TryCast!!!"), + Expr::Sort{..} => panic!("Sort!!!"), + Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), + Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), + Expr::WindowFunction{..} => panic!("WindowFunction!!!"), + Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), + Expr::InList{..} => panic!("InList!!!"), + Expr::Wildcard => panic!("Wildcard!!!"), + _ => String::from("OTHER") + }; + + println!("Right Expression Name: {:?}", right_type); + let mut operands: Vec = Vec::new(); let left_desc: Expr = *left.clone(); - operands.push(left_desc.into()); let right_desc: Expr = *right.clone(); + operands.push(left_desc.into()); operands.push(right_desc.into()); Ok(operands) } @@ -240,6 +311,7 @@ impl PyExpr { } } +<<<<<<< HEAD #[pyo3(name = "getOperatorName")] pub fn get_operator_name(&self) -> PyResult { match &self.expr { @@ -334,6 +406,78 @@ impl PyExpr { _ => panic!("OTHER"), } } +======= + pub fn getOperatorName(&self) -> PyResult { + match &self.expr { + Expr::BinaryExpr { left, op, right } => { + Ok(format!("{}", op)) + }, + _ => Err(PyErr::new::("Catch all triggered ....")) + } + } + + + /// Gets the ScalarValue represented by the Expression + pub fn getValue(&self) -> PyResult { + match &self.expr { + Expr::Alias(..) => panic!("Alias"), + Expr::Column(..) => panic!("Column"), + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(scalarValue) => { + let value = match scalarValue { + datafusion::scalar::ScalarValue::Int64(value) => { + println!("value: {:?}", value.unwrap()); + let val: i64 = value.unwrap(); + Ok(val) + }, + _ => { + panic!("CatchAll") + } + }; + value + }, + Expr::BinaryExpr {..} => panic!("BinaryExpr"), + Expr::Not(..) => panic!("Not!!!"), + Expr::IsNotNull(..) => panic!("IsNotNull!!!"), + Expr::Negative(..) => panic!("Negative!!!"), + Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), + Expr::IsNull(..) => panic!("IsNull!!!"), + Expr::Between{..} => panic!("Between!!!"), + Expr::Case{..} => panic!("Case!!!"), + Expr::Cast{..} => panic!("Cast!!!"), + Expr::TryCast{..} => panic!("TryCast!!!"), + Expr::Sort{..} => panic!("Sort!!!"), + Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), + Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), + Expr::WindowFunction{..} => panic!("WindowFunction!!!"), + Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), + Expr::InList{..} => panic!("InList!!!"), + Expr::Wildcard => panic!("Wildcard!!!"), + _ => panic!("OTHER") + } + } + + + /// Gets the ScalarValue represented by the Expression + pub fn getType(&self) -> PyResult { + match &self.expr { + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(scalarValue) => { + let value = match scalarValue { + datafusion::scalar::ScalarValue::Int64(value) => { + Ok(String::from("BIGINT")) + }, + _ => { + panic!("CatchAll") + } + }; + value + }, + _ => panic!("OTHER") + } + } + +>>>>>>> 9038b85... 
Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral #[staticmethod] pub fn column(value: &str) -> PyExpr { diff --git a/dask_sql/context.py b/dask_sql/context.py index 59bbca795..4a590fa5c 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -10,7 +10,7 @@ from dask.base import optimize from dask.distributed import Client -from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, Expression +from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, LogicalPlan, Expression try: import dask_cuda # noqa: F401 diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 4ef1d64bf..7dc9b3cfa 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -14,6 +14,7 @@ from dask.highlevelgraph import HighLevelGraph from dask.utils import random_state_data +from dask_planner.rust import Expression from dask_sql.datacontainer import DataContainer from dask_sql.mappings import cast_column_to_type, sql_to_python_type from dask_sql.physical.rex import RexConverter diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index b4eb886d1..8c2cdb724 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -6,6 +6,7 @@ from dask_sql.datacontainer import DataContainer from dask_sql.mappings import sql_to_python_value from dask_sql.physical.rex.base import BaseRexPlugin +from dask_planner.rust import Expression if TYPE_CHECKING: import dask_sql From 1e485978878a2238af22fb5d97f2f35ebf624296 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 31 Mar 2022 11:34:36 -0400 Subject: [PATCH 02/61] Updates for test_filter --- dask_planner/src/expression.rs | 160 +++----------------------- dask_sql/physical/rex/core/literal.py | 5 + 2 files changed, 24 insertions(+), 141 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 85b0b29ef..028240386 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -148,7 +148,12 @@ impl PyExpr { }) } +<<<<<<< HEAD pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { +======= + + pub fn column_name(&self) -> String { +>>>>>>> 2d16579... Updates for test_filter match &self.expr { Expr::Alias(expr, name) => { println!("Alias encountered with name: {:?}", name); @@ -216,79 +221,12 @@ impl PyExpr { /// Gets the operands for a BinaryExpr call pub fn getOperands(&self) -> PyResult> { - println!("PyExpr in getOperands(): {:?}", self.expr); match &self.expr { Expr::BinaryExpr {left, op, right} => { - - let left_type: String = match *left.clone() { - Expr::Alias(..) => String::from("Alias"), - Expr::Column(..) => String::from("Column"), - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(..) => panic!("Literal!!!"), - Expr::BinaryExpr {..} => String::from("BinaryExpr"), - Expr::Not(..) => panic!("Not!!!"), - Expr::IsNotNull(..) => panic!("IsNotNull!!!"), - Expr::Negative(..) => panic!("Negative!!!"), - Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), - Expr::IsNull(..) 
=> panic!("IsNull!!!"), - Expr::Between{..} => panic!("Between!!!"), - Expr::Case{..} => panic!("Case!!!"), - Expr::Cast{..} => panic!("Cast!!!"), - Expr::TryCast{..} => panic!("TryCast!!!"), - Expr::Sort{..} => panic!("Sort!!!"), - Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), - Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), - Expr::WindowFunction{..} => panic!("WindowFunction!!!"), - Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), - Expr::InList{..} => panic!("InList!!!"), - Expr::Wildcard => panic!("Wildcard!!!"), - _ => String::from("OTHER") - }; - - println!("Left Expression Name: {:?}", left_type); - - let right_type: String = match *right.clone() { - Expr::Alias(..) => String::from("Alias"), - Expr::Column(..) => String::from("Column"), - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(scalarValue) => { - let value = match scalarValue { - datafusion::scalar::ScalarValue::Int64(value) => { - println!("value: {:?}", value.unwrap()); - String::from("I64 Value") - }, - _ => { - String::from("CatchAll") - } - }; - String::from("Literal") - }, - Expr::BinaryExpr {..} => String::from("BinaryExpr"), - Expr::Not(..) => panic!("Not!!!"), - Expr::IsNotNull(..) => panic!("IsNotNull!!!"), - Expr::Negative(..) => panic!("Negative!!!"), - Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), - Expr::IsNull(..) => panic!("IsNull!!!"), - Expr::Between{..} => panic!("Between!!!"), - Expr::Case{..} => panic!("Case!!!"), - Expr::Cast{..} => panic!("Cast!!!"), - Expr::TryCast{..} => panic!("TryCast!!!"), - Expr::Sort{..} => panic!("Sort!!!"), - Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), - Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), - Expr::WindowFunction{..} => panic!("WindowFunction!!!"), - Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), - Expr::InList{..} => panic!("InList!!!"), - Expr::Wildcard => panic!("Wildcard!!!"), - _ => String::from("OTHER") - }; - - println!("Right Expression Name: {:?}", right_type); - let mut operands: Vec = Vec::new(); let left_desc: Expr = *left.clone(); - let right_desc: Expr = *right.clone(); operands.push(left_desc.into()); + let right_desc: Expr = *right.clone(); operands.push(right_desc.into()); Ok(operands) } @@ -311,7 +249,6 @@ impl PyExpr { } } -<<<<<<< HEAD #[pyo3(name = "getOperatorName")] pub fn get_operator_name(&self) -> PyResult { match &self.expr { @@ -406,78 +343,6 @@ impl PyExpr { _ => panic!("OTHER"), } } -======= - pub fn getOperatorName(&self) -> PyResult { - match &self.expr { - Expr::BinaryExpr { left, op, right } => { - Ok(format!("{}", op)) - }, - _ => Err(PyErr::new::("Catch all triggered ....")) - } - } - - - /// Gets the ScalarValue represented by the Expression - pub fn getValue(&self) -> PyResult { - match &self.expr { - Expr::Alias(..) => panic!("Alias"), - Expr::Column(..) => panic!("Column"), - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(scalarValue) => { - let value = match scalarValue { - datafusion::scalar::ScalarValue::Int64(value) => { - println!("value: {:?}", value.unwrap()); - let val: i64 = value.unwrap(); - Ok(val) - }, - _ => { - panic!("CatchAll") - } - }; - value - }, - Expr::BinaryExpr {..} => panic!("BinaryExpr"), - Expr::Not(..) => panic!("Not!!!"), - Expr::IsNotNull(..) => panic!("IsNotNull!!!"), - Expr::Negative(..) => panic!("Negative!!!"), - Expr::GetIndexedField{..} => panic!("GetIndexedField!!!"), - Expr::IsNull(..) 
=> panic!("IsNull!!!"), - Expr::Between{..} => panic!("Between!!!"), - Expr::Case{..} => panic!("Case!!!"), - Expr::Cast{..} => panic!("Cast!!!"), - Expr::TryCast{..} => panic!("TryCast!!!"), - Expr::Sort{..} => panic!("Sort!!!"), - Expr::ScalarFunction{..} => panic!("ScalarFunction!!!"), - Expr::AggregateFunction{..} => panic!("AggregateFunction!!!"), - Expr::WindowFunction{..} => panic!("WindowFunction!!!"), - Expr::AggregateUDF{..} => panic!("AggregateUDF!!!"), - Expr::InList{..} => panic!("InList!!!"), - Expr::Wildcard => panic!("Wildcard!!!"), - _ => panic!("OTHER") - } - } - - - /// Gets the ScalarValue represented by the Expression - pub fn getType(&self) -> PyResult { - match &self.expr { - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(scalarValue) => { - let value = match scalarValue { - datafusion::scalar::ScalarValue::Int64(value) => { - Ok(String::from("BIGINT")) - }, - _ => { - panic!("CatchAll") - } - }; - value - }, - _ => panic!("OTHER") - } - } - ->>>>>>> 9038b85... Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral #[staticmethod] pub fn column(value: &str) -> PyExpr { @@ -672,7 +537,12 @@ impl PyExpr { // fn getValue(&mut self) -> T; // } +<<<<<<< HEAD // /// Expansion macro to get all typed values from a DataFusion Expr +======= + +// /// Expansion macro to get all typed values from a Datafusion Expr +>>>>>>> 2d16579... Updates for test_filter // macro_rules! get_typed_value { // ($t:ty, $func_name:ident) => { // impl ObtainValue<$t> for PyExpr { @@ -709,6 +579,10 @@ impl PyExpr { // get_typed_value!(f32, Float32); // get_typed_value!(f64, Float64); +<<<<<<< HEAD +======= + +>>>>>>> 2d16579... Updates for test_filter // get_typed_value!(for usize u8 u16 u32 u64 isize i8 i16 i32 i64 bool f32 f64); // get_typed_value!(usize, Integer); // get_typed_value!(isize, ); @@ -721,6 +595,10 @@ impl PyExpr { // Date32(Option), // Date64(Option), +<<<<<<< HEAD +======= + +>>>>>>> 2d16579... Updates for test_filter #[pyproto] impl PyMappingProtocol for PyExpr { fn __getitem__(&self, key: &str) -> PyResult { diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index 8c2cdb724..345dc707d 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -1,4 +1,9 @@ +<<<<<<< HEAD import logging +======= +from resource import RUSAGE_THREAD +import tty +>>>>>>> 2d16579... Updates for test_filter from typing import TYPE_CHECKING, Any import dask.dataframe as dd From fd41a8c54172369b30f9f4966b038de1404065dc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 31 Mar 2022 11:58:35 -0400 Subject: [PATCH 03/61] more of test_filter.py working with the exception of some date pytests --- dask_planner/src/expression.rs | 36 ++++++++++++++++++++------- dask_planner/src/sql.rs | 6 +++++ dask_sql/physical/rex/core/literal.py | 5 ---- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 028240386..f9fbee5f6 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -121,13 +121,8 @@ impl PyExpr { Expr::Alias(..) => "Alias", Expr::Column(..) => "Column", Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), -<<<<<<< HEAD Expr::Literal(..) => "Literal", Expr::BinaryExpr { .. } => "BinaryExpr", -======= - Expr::Literal(..) => String::from("Literal"), - Expr::BinaryExpr {..} => String::from("BinaryExpr"), ->>>>>>> 9038b85... 
Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral Expr::Not(..) => panic!("Not!!!"), Expr::IsNotNull(..) => panic!("IsNotNull!!!"), Expr::Negative(..) => panic!("Negative!!!"), @@ -148,12 +143,8 @@ impl PyExpr { }) } -<<<<<<< HEAD - pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { -======= pub fn column_name(&self) -> String { ->>>>>>> 2d16579... Updates for test_filter match &self.expr { Expr::Alias(expr, name) => { println!("Alias encountered with name: {:?}", name); @@ -519,6 +510,7 @@ impl PyExpr { } } +<<<<<<< HEAD #[pyo3(name = "getStringValue")] pub fn string_value(&mut self) -> String { match &self.expr { @@ -531,6 +523,32 @@ impl PyExpr { _ => panic!("getValue() - Non literal value encountered"), } } +======= + pub fn getStringValue(&mut self) -> String { + match &self.expr { + Expr::Literal(scalar_value) => { + match scalar_value { + ScalarValue::Utf8(iv) => { + String::from(iv.clone().unwrap()) + }, + _ => { + panic!("getValue() - Unexpected value") + } + } + }, + _ => panic!("getValue() - Non literal value encountered") + } + } + + +// get_typed_value!(i8, Int8); +// get_typed_value!(i16, Int16); +// get_typed_value!(i32, Int32); +// get_typed_value!(i64, Int64); +// get_typed_value!(bool, Boolean); +// get_typed_value!(f32, Float32); +// get_typed_value!(f64, Float64); +>>>>>>> a4aeee5... more of test_filter.py working with the exception of some date pytests } // pub trait ObtainValue { diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 11d37df72..573ce5b1c 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -15,7 +15,13 @@ use datafusion::sql::parser::DFParser; use datafusion::sql::planner::{ContextProvider, SqlToRel}; use datafusion_expr::ScalarFunctionImplementation; +<<<<<<< HEAD use std::collections::HashMap; +======= +use datafusion::physical_plan::udf::ScalarUDF; +use datafusion::physical_plan::udaf::AggregateUDF; + +>>>>>>> a4aeee5... more of test_filter.py working with the exception of some date pytests use std::sync::Arc; use pyo3::prelude::*; diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index 345dc707d..8c2cdb724 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -1,9 +1,4 @@ -<<<<<<< HEAD import logging -======= -from resource import RUSAGE_THREAD -import tty ->>>>>>> 2d16579... 
Updates for test_filter from typing import TYPE_CHECKING, Any import dask.dataframe as dd From 682c009ce001d6e98e35002324ff8bbaa99eff2c Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Fri, 25 Mar 2022 16:18:56 -0400 Subject: [PATCH 04/61] Add workflow to keep datafusion dev branch up to date (#440) --- .github/workflows/datafusion-sync.yml | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/datafusion-sync.yml diff --git a/.github/workflows/datafusion-sync.yml b/.github/workflows/datafusion-sync.yml new file mode 100644 index 000000000..fd544eeae --- /dev/null +++ b/.github/workflows/datafusion-sync.yml @@ -0,0 +1,30 @@ +name: Keep datafusion branch up to date +on: + push: + branches: + - main + +# When this workflow is queued, automatically cancel any previous running +# or pending jobs +concurrency: + group: datafusion-sync + cancel-in-progress: true + +jobs: + sync-branches: + runs-on: ubuntu-latest + if: github.repository == 'dask-contrib/dask-sql' + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Node + uses: actions/setup-node@v2 + with: + node-version: 12 + - name: Opening pull request + id: pull + uses: tretuna/sync-branches@1.4.0 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + FROM_BRANCH: main + TO_BRANCH: datafusion-sql-planner From ab69dd8e42d350c357c3d02b7f0d697eecb36bf1 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 12 Apr 2022 21:23:01 -0400 Subject: [PATCH 05/61] Include setuptools-rust in conda build recipie, in host and run --- continuous_integration/recipe/meta.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml index 1bfeb19cb..ce68a9c7c 100644 --- a/continuous_integration/recipe/meta.yaml +++ b/continuous_integration/recipe/meta.yaml @@ -30,6 +30,7 @@ requirements: - setuptools-rust>=1.1.2 run: - python + - setuptools-rust>=1.1.2 - dask >=2022.3.0 - pandas >=1.0.0 - fastapi >=0.61.1 From ce4c31e2ab00fa2b366a668570357cb93cb9bf07 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Wed, 20 Apr 2022 18:20:46 -0400 Subject: [PATCH 06/61] Remove PyArrow dependency --- .github/workflows/test.yml | 1 - continuous_integration/environment-3.10-dev.yaml | 1 - continuous_integration/environment-3.8-dev.yaml | 1 - continuous_integration/environment-3.9-dev.yaml | 1 - continuous_integration/recipe/meta.yaml | 1 - docker/conda.txt | 1 - docker/main.dockerfile | 1 - setup.py | 1 - 8 files changed, 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cc5b078bd..4c8016b7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -162,7 +162,6 @@ jobs: - name: Install dependencies and nothing else run: | conda install setuptools-rust - conda install pyarrow>=4.0.0 pip install -e . 
which python diff --git a/continuous_integration/environment-3.10-dev.yaml b/continuous_integration/environment-3.10-dev.yaml index 51a7f6052..6730402ec 100644 --- a/continuous_integration/environment-3.10-dev.yaml +++ b/continuous_integration/environment-3.10-dev.yaml @@ -23,7 +23,6 @@ dependencies: - pre-commit>=2.11.1 - prompt_toolkit>=3.0.8 - psycopg2>=2.9.1 -- pyarrow>=4.0.0 - pygments>=2.7.1 - pyhive>=0.6.4 - pytest-cov>=2.10.1 diff --git a/continuous_integration/environment-3.8-dev.yaml b/continuous_integration/environment-3.8-dev.yaml index 10132bff6..01dec9ee6 100644 --- a/continuous_integration/environment-3.8-dev.yaml +++ b/continuous_integration/environment-3.8-dev.yaml @@ -23,7 +23,6 @@ dependencies: - pre-commit>=2.11.1 - prompt_toolkit>=3.0.8 - psycopg2>=2.9.1 -- pyarrow>=4.0.0 - pygments>=2.7.1 - pyhive>=0.6.4 - pytest-cov>=2.10.1 diff --git a/continuous_integration/environment-3.9-dev.yaml b/continuous_integration/environment-3.9-dev.yaml index 571a265a7..1b962a19c 100644 --- a/continuous_integration/environment-3.9-dev.yaml +++ b/continuous_integration/environment-3.9-dev.yaml @@ -25,7 +25,6 @@ dependencies: - pre-commit>=2.11.1 - prompt_toolkit>=3.0.8 - psycopg2>=2.9.1 -- pyarrow>=4.0.0 - pygments>=2.7.1 - pyhive>=0.6.4 - pytest-cov>=2.10.1 diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml index ce68a9c7c..6d6ef2ced 100644 --- a/continuous_integration/recipe/meta.yaml +++ b/continuous_integration/recipe/meta.yaml @@ -40,7 +40,6 @@ requirements: - pygments - nest-asyncio - tabulate - - pyarrow>=4.0.0 test: imports: diff --git a/docker/conda.txt b/docker/conda.txt index 81fc96a9d..ddcac2de8 100644 --- a/docker/conda.txt +++ b/docker/conda.txt @@ -13,7 +13,6 @@ tzlocal>=2.1 fastapi>=0.61.1 nest-asyncio>=1.4.3 uvicorn>=0.11.3 -pyarrow>=4.0.0 prompt_toolkit>=3.0.8 pygments>=2.7.1 dask-ml>=2022.1.22 diff --git a/docker/main.dockerfile b/docker/main.dockerfile index e69ef79a3..6f5a54b8e 100644 --- a/docker/main.dockerfile +++ b/docker/main.dockerfile @@ -13,7 +13,6 @@ RUN conda config --add channels conda-forge \ "tzlocal>=2.1" \ "fastapi>=0.61.1" \ "uvicorn>=0.11.3" \ - "pyarrow>=4.0.0" \ "prompt_toolkit>=3.0.8" \ "pygments>=2.7.1" \ "dask-ml>=2022.1.22" \ diff --git a/setup.py b/setup.py index 8aa1c216a..abe657eb9 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,6 @@ "pytest-cov>=2.10.1", "mock>=4.0.3", "sphinx>=3.2.1", - "pyarrow==7.0.0", "dask-ml>=2022.1.22", "scikit-learn>=0.24.2", "intake>=0.6.0", From 8785b8c3713e70a6a52c3a71e29ba78a47a435c5 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 14:18:08 -0400 Subject: [PATCH 07/61] rebase with datafusion-sql-planner --- dask_planner/src/expression.rs | 241 +++++++++++++-------------------- dask_planner/src/sql.rs | 6 - 2 files changed, 92 insertions(+), 155 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f9fbee5f6..5ec66bee6 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -105,77 +105,58 @@ impl PyObjectProtocol for PyExpr { } } -#[pymethods] impl PyExpr { - #[staticmethod] - pub fn literal(value: PyScalarValue) -> PyExpr { - lit(value.scalar_value).into() - } - - /// Examine the current/"self" PyExpr and return its "type" - /// In this context a "type" is what Dask-SQL Python - /// RexConverter plugin instance should be invoked to handle - /// the Rex conversion - pub fn get_expr_type(&self) -> String { - String::from(match &self.expr { - Expr::Alias(..) 
=> "Alias", - Expr::Column(..) => "Column", - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(..) => "Literal", - Expr::BinaryExpr { .. } => "BinaryExpr", - Expr::Not(..) => panic!("Not!!!"), - Expr::IsNotNull(..) => panic!("IsNotNull!!!"), - Expr::Negative(..) => panic!("Negative!!!"), - Expr::GetIndexedField { .. } => panic!("GetIndexedField!!!"), - Expr::IsNull(..) => panic!("IsNull!!!"), - Expr::Between { .. } => panic!("Between!!!"), - Expr::Case { .. } => panic!("Case!!!"), - Expr::Cast { .. } => "Cast", - Expr::TryCast { .. } => panic!("TryCast!!!"), - Expr::Sort { .. } => panic!("Sort!!!"), - Expr::ScalarFunction { .. } => "ScalarFunction", - Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::WindowFunction { .. } => panic!("WindowFunction!!!"), - Expr::AggregateUDF { .. } => panic!("AggregateUDF!!!"), - Expr::InList { .. } => panic!("InList!!!"), - Expr::Wildcard => panic!("Wildcard!!!"), - _ => "OTHER", - }) - } - - - pub fn column_name(&self) -> String { + fn _column_name(&self, mut plan: LogicalPlan) -> String { match &self.expr { Expr::Alias(expr, name) => { println!("Alias encountered with name: {:?}", name); + // let reference: Expr = *expr.as_ref(); + // let plan: logical::PyLogicalPlan = reference.input().clone().into(); // Only certain LogicalPlan variants are valid in this nested Alias scenario so we // extract the valid ones and error on the invalid ones match expr.as_ref() { Expr::Column(col) => { // First we must iterate the current node before getting its input - match plan.current_node() { - LogicalPlan::Projection(proj) => match proj.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - let mut exprs = agg.group_expr.clone(); - exprs.extend_from_slice(&agg.aggr_expr); - match &exprs[plan.get_index(col)] { - Expr::AggregateFunction { args, .. } => match &args[0] { - Expr::Column(col) => { - println!("AGGREGATE COLUMN IS {}", col.name); - col.name.clone() + match plan { + LogicalPlan::Projection(proj) => { + match proj.input.as_ref() { + LogicalPlan::Aggregate(agg) => { + let mut exprs = agg.group_expr.clone(); + exprs.extend_from_slice(&agg.aggr_expr); + let col_index: usize = + proj.input.schema().index_of_column(col).unwrap(); + // match &exprs[plan.get_index(col)] { + match &exprs[col_index] { + Expr::AggregateFunction { args, .. 
} => { + match &args[0] { + Expr::Column(col) => { + println!( + "AGGREGATE COLUMN IS {}", + col.name + ); + col.name.clone() + } + _ => name.clone(), + } } _ => name.clone(), - }, - _ => name.clone(), + } + } + _ => { + println!("Encountered a non-Aggregate type"); + + name.clone() } } - _ => name.clone(), - }, + } _ => name.clone(), } } - _ => name.clone(), + _ => { + println!("Encountered a non Expr::Column instance"); + name.clone() + } } } Expr::Column(column) => column.name.clone(), @@ -209,11 +190,66 @@ impl PyExpr { _ => panic!("Nothing found!!!"), } } +} + +#[pymethods] +impl PyExpr { + #[staticmethod] + pub fn literal(value: PyScalarValue) -> PyExpr { + lit(value.scalar_value).into() + } + + /// If this Expression instances references an existing + /// Column in the SQL parse tree or not + #[pyo3(name = "isInputReference")] + pub fn is_input_reference(&self) -> PyResult { + match &self.expr { + Expr::Column(_col) => Ok(true), + _ => Ok(false), + } + } + + /// Examine the current/"self" PyExpr and return its "type" + /// In this context a "type" is what Dask-SQL Python + /// RexConverter plugin instance should be invoked to handle + /// the Rex conversion + pub fn get_expr_type(&self) -> String { + String::from(match &self.expr { + Expr::Alias(..) => "Alias", + Expr::Column(..) => "Column", + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(..) => "Literal", + Expr::BinaryExpr { .. } => "BinaryExpr", + Expr::Not(..) => panic!("Not!!!"), + Expr::IsNotNull(..) => panic!("IsNotNull!!!"), + Expr::Negative(..) => panic!("Negative!!!"), + Expr::GetIndexedField { .. } => panic!("GetIndexedField!!!"), + Expr::IsNull(..) => panic!("IsNull!!!"), + Expr::Between { .. } => panic!("Between!!!"), + Expr::Case { .. } => panic!("Case!!!"), + Expr::Cast { .. } => "Cast", + Expr::TryCast { .. } => panic!("TryCast!!!"), + Expr::Sort { .. } => panic!("Sort!!!"), + Expr::ScalarFunction { .. } => "ScalarFunction", + Expr::AggregateFunction { .. } => "AggregateFunction", + Expr::WindowFunction { .. } => panic!("WindowFunction!!!"), + Expr::AggregateUDF { .. } => panic!("AggregateUDF!!!"), + Expr::InList { .. 
} => panic!("InList!!!"), + Expr::Wildcard => panic!("Wildcard!!!"), + _ => "OTHER", + }) + } + + /// Python friendly shim code to get the name of a column referenced by an expression + pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { + self._column_name(plan.current_node()) + } /// Gets the operands for a BinaryExpr call - pub fn getOperands(&self) -> PyResult> { + #[pyo3(name = "getOperands")] + pub fn get_operands(&self) -> PyResult> { match &self.expr { - Expr::BinaryExpr {left, op, right} => { + Expr::BinaryExpr { left, op: _, right } => { let mut operands: Vec = Vec::new(); let left_desc: Expr = *left.clone(); operands.push(left_desc.into()); @@ -510,7 +546,6 @@ impl PyExpr { } } -<<<<<<< HEAD #[pyo3(name = "getStringValue")] pub fn string_value(&mut self) -> String { match &self.expr { @@ -523,100 +558,8 @@ impl PyExpr { _ => panic!("getValue() - Non literal value encountered"), } } -======= - pub fn getStringValue(&mut self) -> String { - match &self.expr { - Expr::Literal(scalar_value) => { - match scalar_value { - ScalarValue::Utf8(iv) => { - String::from(iv.clone().unwrap()) - }, - _ => { - panic!("getValue() - Unexpected value") - } - } - }, - _ => panic!("getValue() - Non literal value encountered") - } - } - - -// get_typed_value!(i8, Int8); -// get_typed_value!(i16, Int16); -// get_typed_value!(i32, Int32); -// get_typed_value!(i64, Int64); -// get_typed_value!(bool, Boolean); -// get_typed_value!(f32, Float32); -// get_typed_value!(f64, Float64); ->>>>>>> a4aeee5... more of test_filter.py working with the exception of some date pytests } -// pub trait ObtainValue { -// fn getValue(&mut self) -> T; -// } - -<<<<<<< HEAD -// /// Expansion macro to get all typed values from a DataFusion Expr -======= - -// /// Expansion macro to get all typed values from a Datafusion Expr ->>>>>>> 2d16579... Updates for test_filter -// macro_rules! get_typed_value { -// ($t:ty, $func_name:ident) => { -// impl ObtainValue<$t> for PyExpr { -// #[inline] -// fn getValue(&mut self) -> $t -// { -// match &self.expr { -// Expr::Literal(scalar_value) => { -// match scalar_value { -// ScalarValue::$func_name(iv) => { -// iv.unwrap() -// }, -// _ => { -// panic!("getValue() - Unexpected value") -// } -// } -// }, -// _ => panic!("getValue() - Non literal value encountered") -// } -// } -// } -// } -// } - -// get_typed_value!(u8, UInt8); -// get_typed_value!(u16, UInt16); -// get_typed_value!(u32, UInt32); -// get_typed_value!(u64, UInt64); -// get_typed_value!(i8, Int8); -// get_typed_value!(i16, Int16); -// get_typed_value!(i32, Int32); -// get_typed_value!(i64, Int64); -// get_typed_value!(bool, Boolean); -// get_typed_value!(f32, Float32); -// get_typed_value!(f64, Float64); - -<<<<<<< HEAD -======= - ->>>>>>> 2d16579... Updates for test_filter -// get_typed_value!(for usize u8 u16 u32 u64 isize i8 i16 i32 i64 bool f32 f64); -// get_typed_value!(usize, Integer); -// get_typed_value!(isize, ); -// Decimal128(Option, usize, usize), -// Utf8(Option), -// LargeUtf8(Option), -// Binary(Option>), -// LargeBinary(Option>), -// List(Option, Global>>, Box), -// Date32(Option), -// Date64(Option), - -<<<<<<< HEAD -======= - ->>>>>>> 2d16579... 
Updates for test_filter #[pyproto] impl PyMappingProtocol for PyExpr { fn __getitem__(&self, key: &str) -> PyResult { diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 573ce5b1c..11d37df72 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -15,13 +15,7 @@ use datafusion::sql::parser::DFParser; use datafusion::sql::planner::{ContextProvider, SqlToRel}; use datafusion_expr::ScalarFunctionImplementation; -<<<<<<< HEAD use std::collections::HashMap; -======= -use datafusion::physical_plan::udf::ScalarUDF; -use datafusion::physical_plan::udaf::AggregateUDF; - ->>>>>>> a4aeee5... more of test_filter.py working with the exception of some date pytests use std::sync::Arc; use pyo3::prelude::*; From 3e45ab8d692f59bc33e9708e1b309cc29a08ced3 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 14:24:14 -0400 Subject: [PATCH 08/61] refactor changes that were inadvertent during rebase --- .github/workflows/datafusion-sync.yml | 30 --------------------------- dask_sql/context.py | 9 +++++--- dask_sql/physical/rex/core/call.py | 1 - dask_sql/physical/rex/core/literal.py | 1 - 4 files changed, 6 insertions(+), 35 deletions(-) delete mode 100644 .github/workflows/datafusion-sync.yml diff --git a/.github/workflows/datafusion-sync.yml b/.github/workflows/datafusion-sync.yml deleted file mode 100644 index fd544eeae..000000000 --- a/.github/workflows/datafusion-sync.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: Keep datafusion branch up to date -on: - push: - branches: - - main - -# When this workflow is queued, automatically cancel any previous running -# or pending jobs -concurrency: - group: datafusion-sync - cancel-in-progress: true - -jobs: - sync-branches: - runs-on: ubuntu-latest - if: github.repository == 'dask-contrib/dask-sql' - steps: - - name: Checkout - uses: actions/checkout@v2 - - name: Set up Node - uses: actions/setup-node@v2 - with: - node-version: 12 - - name: Opening pull request - id: pull - uses: tretuna/sync-branches@1.4.0 - with: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - FROM_BRANCH: main - TO_BRANCH: datafusion-sql-planner diff --git a/dask_sql/context.py b/dask_sql/context.py index 4a590fa5c..89912d968 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -2,7 +2,7 @@ import inspect import logging import warnings -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union import dask.dataframe as dd import pandas as pd @@ -10,7 +10,7 @@ from dask.base import optimize from dask.distributed import Client -from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, LogicalPlan, Expression +from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable try: import dask_cuda # noqa: F401 @@ -31,6 +31,9 @@ from dask_sql.physical.rel import RelConverter, custom, logical from dask_sql.physical.rex import RexConverter, core +if TYPE_CHECKING: + from dask_planner.rust import Expression + logger = logging.getLogger(__name__) @@ -688,7 +691,7 @@ def stop_server(self): # pragma: no cover self.sql_server = None - def fqn(self, identifier: Expression) -> Tuple[str, str]: + def fqn(self, identifier: "Expression") -> Tuple[str, str]: """ Return the fully qualified name of an object, maybe including the schema name. 
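Note on the context.py hunk above: the Rust Expression type is now imported only under typing.TYPE_CHECKING and referenced through a string annotation, so importing dask_sql.context no longer requires the compiled dask_planner extension at import time. A minimal sketch of that pattern follows; the function body is purely illustrative and is not the real Context.fqn implementation.

from typing import TYPE_CHECKING, Tuple

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime
    from dask_planner.rust import Expression


def fqn(identifier: "Expression") -> Tuple[str, str]:
    """Return the fully qualified name of an object, maybe including the schema name."""
    # Placeholder body; the real method inspects the identifier expression.
    return ("schema_name", "table_name")
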
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 7dc9b3cfa..4ef1d64bf 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -14,7 +14,6 @@ from dask.highlevelgraph import HighLevelGraph from dask.utils import random_state_data -from dask_planner.rust import Expression from dask_sql.datacontainer import DataContainer from dask_sql.mappings import cast_column_to_type, sql_to_python_type from dask_sql.physical.rex import RexConverter diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index 8c2cdb724..b4eb886d1 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -6,7 +6,6 @@ from dask_sql.datacontainer import DataContainer from dask_sql.mappings import sql_to_python_value from dask_sql.physical.rex.base import BaseRexPlugin -from dask_planner.rust import Expression if TYPE_CHECKING: import dask_sql From 1734b89da061a4ab66d6c7e8202eadbf0d0d41a3 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 14:51:01 -0400 Subject: [PATCH 09/61] timestamp with loglca time zone --- dask_sql/mappings.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 9efa908dc..edd919085 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -6,7 +6,6 @@ import dask.dataframe as dd import numpy as np import pandas as pd -import pyarrow as pa from dask_sql._compat import FLOAT_NAN_IMPLEMENTED @@ -88,7 +87,7 @@ def python_to_sql_type(python_type): python_type = python_type.type if pd.api.types.is_datetime64tz_dtype(python_type): - return pa.timestamp("ms", tz="UTC") + return "TIMESTAMP_WITH_LOCAL_TIME_ZONE" try: return _PYTHON_TO_SQL[python_type] From ac7d9f6f92b6c67d3f6d7ca4207f0214cf6ffc6e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 21 Apr 2022 14:12:54 -0600 Subject: [PATCH 10/61] Bump DataFusion version (#494) * bump DataFusion version * remove unnecessary downcasts and use separate structs for TableSource and TableProvider --- dask_planner/Cargo.toml | 4 +- dask_planner/src/expression.rs | 6 +-- dask_planner/src/sql.rs | 10 ++-- dask_planner/src/sql/logical.rs | 2 +- dask_planner/src/sql/logical/aggregate.rs | 5 +- dask_planner/src/sql/logical/filter.rs | 4 +- dask_planner/src/sql/logical/join.rs | 25 ++-------- dask_planner/src/sql/logical/projection.rs | 6 +-- dask_planner/src/sql/table.rs | 54 +++++++++++++--------- 9 files changed, 52 insertions(+), 64 deletions(-) diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 0ed665bf6..c66fc637c 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -12,8 +12,8 @@ rust-version = "1.59" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" pyo3 = { version = "0.15", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "583b4ab8dfe6148a7387841d112dd50b1151f6fb" } -datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "583b4ab8dfe6148a7387841d112dd50b1151f6fb" } +datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } +datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } sqlparser = "0.14.0" diff 
--git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 5ec66bee6..31043f118 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -6,13 +6,11 @@ use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; use std::convert::{From, Into}; use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{col, lit, Expr}; +use datafusion_expr::{col, lit, BuiltinScalarFunction, Expr}; use datafusion::scalar::ScalarValue; -pub use datafusion::logical_plan::plan::LogicalPlan; - -use datafusion::logical_expr::BuiltinScalarFunction; +pub use datafusion_expr::LogicalPlan; /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expression", module = "datafusion", subclass)] diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 11d37df72..90d8f8401 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -18,6 +18,7 @@ use datafusion_expr::ScalarFunctionImplementation; use std::collections::HashMap; use std::sync::Arc; +use crate::sql::table::DaskTableProvider; use pyo3::prelude::*; /// DaskSQLContext is main interface used for interacting with DataFusion to @@ -51,7 +52,6 @@ impl ContextProvider for DaskSQLContext { match self.schemas.get(&self.default_schema_name) { Some(schema) => { let mut resp = None; - let mut table_name: String = "".to_string(); for (_table_name, table) in &schema.tables { if table.name.eq(&name.table()) { // Build the Schema here @@ -67,13 +67,11 @@ impl ContextProvider for DaskSQLContext { } resp = Some(Schema::new(fields)); - table_name = _table_name.clone(); } } - Some(Arc::new(table::DaskTableProvider::new( - Arc::new(resp.unwrap()), - table_name, - ))) + Some(Arc::new(DaskTableProvider::new(Arc::new( + table::DaskTableSource::new(Arc::new(resp.unwrap())), + )))) } None => { DataFusionError::Execution(format!( diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index c3d67b50d..283aab317 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -4,7 +4,7 @@ mod filter; mod join; pub mod projection; -pub use datafusion::logical_plan::plan::LogicalPlan; +pub use datafusion_expr::LogicalPlan; use datafusion::prelude::Column; diff --git a/dask_planner/src/sql/logical/aggregate.rs b/dask_planner/src/sql/logical/aggregate.rs index 8d3aac2ab..e260e4bd7 100644 --- a/dask_planner/src/sql/logical/aggregate.rs +++ b/dask_planner/src/sql/logical/aggregate.rs @@ -1,8 +1,7 @@ use crate::expression::PyExpr; -use datafusion::logical_plan::Expr; -use datafusion::logical_plan::plan::Aggregate; -pub use datafusion::logical_plan::plan::{JoinType, LogicalPlan}; +use datafusion_expr::{logical_plan::Aggregate, Expr}; +pub use datafusion_expr::{logical_plan::JoinType, LogicalPlan}; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/filter.rs b/dask_planner/src/sql/logical/filter.rs index a8ef32d5b..2ef163721 100644 --- a/dask_planner/src/sql/logical/filter.rs +++ b/dask_planner/src/sql/logical/filter.rs @@ -1,7 +1,7 @@ use crate::expression::PyExpr; -use datafusion::logical_plan::plan::Filter; -pub use datafusion::logical_plan::plan::LogicalPlan; +use datafusion_expr::logical_plan::Filter; +pub use datafusion_expr::LogicalPlan; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index 7558a0fc7..ccb77ef6b 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -1,8 +1,7 @@ use crate::sql::column; -use 
crate::sql::table; -use datafusion::logical_plan::plan::Join; -pub use datafusion::logical_plan::plan::{JoinType, LogicalPlan}; +use datafusion_expr::logical_plan::Join; +pub use datafusion_expr::{logical_plan::JoinType, LogicalPlan}; use pyo3::prelude::*; @@ -17,28 +16,12 @@ impl PyJoin { #[pyo3(name = "getJoinConditions")] pub fn join_conditions(&mut self) -> PyResult> { let lhs_table_name: String = match &*self.join.left { - LogicalPlan::TableScan(_table_scan) => { - let tbl: String = _table_scan - .source - .as_any() - .downcast_ref::() - .unwrap() - .table_name(); - tbl - } + LogicalPlan::TableScan(scan) => scan.table_name.clone(), _ => panic!("lhs Expected TableScan but something else was received!"), }; let rhs_table_name: String = match &*self.join.right { - LogicalPlan::TableScan(_table_scan) => { - let tbl: String = _table_scan - .source - .as_any() - .downcast_ref::() - .unwrap() - .table_name(); - tbl - } + LogicalPlan::TableScan(scan) => scan.table_name.clone(), _ => panic!("rhs Expected TableScan but something else was received!"), }; diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index c1eaba792..d5ef65827 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -1,9 +1,7 @@ use crate::expression::PyExpr; -pub use datafusion::logical_plan::plan::LogicalPlan; -use datafusion::logical_plan::plan::Projection; - -use datafusion::logical_plan::Expr; +pub use datafusion_expr::LogicalPlan; +use datafusion_expr::{logical_plan::Projection, Expr}; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 299e43dd8..5ee4ec0e3 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -6,32 +6,50 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; pub use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; -use datafusion::logical_plan::plan::LogicalPlan; -use datafusion::logical_plan::Expr; use datafusion::physical_plan::{empty::EmptyExec, project_schema, ExecutionPlan}; +use datafusion_expr::{Expr, LogicalPlan, TableSource}; use pyo3::prelude::*; use std::any::Any; use std::sync::Arc; -/// DaskTableProvider -pub struct DaskTableProvider { +/// DaskTable wrapper that is compatible with DataFusion logical query plans +pub struct DaskTableSource { schema: SchemaRef, - table_name: String, } -impl DaskTableProvider { +impl DaskTableSource { /// Initialize a new `EmptyTable` from a schema. 
- pub fn new(schema: SchemaRef, table_name: String) -> Self { - Self { schema, table_name } + pub fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +/// Implement TableSource, used in the logical query plan and in logical query optimizations +#[async_trait] +impl TableSource for DaskTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() } +} - pub fn table_name(&self) -> String { - self.table_name.clone() +/// DaskTable wrapper that is compatible with DataFusion physical query plans +pub struct DaskTableProvider { + source: Arc, +} + +impl DaskTableProvider { + pub fn new(source: Arc) -> Self { + Self { source } } } +/// Implement TableProvider, used for physical query plans and execution #[async_trait] impl TableProvider for DaskTableProvider { fn as_any(&self) -> &dyn Any { @@ -39,7 +57,7 @@ impl TableProvider for DaskTableProvider { } fn schema(&self) -> SchemaRef { - self.schema.clone() + self.source.schema.clone() } async fn scan( @@ -49,7 +67,7 @@ impl TableProvider for DaskTableProvider { _limit: Option, ) -> Result, DataFusionError> { // even though there is no data, projections apply - let projected_schema = project_schema(&self.schema, projection.as_ref())?; + let projected_schema = project_schema(&self.source.schema, projection.as_ref())?; Ok(Arc::new(EmptyExec::new(false, projected_schema))) } } @@ -102,14 +120,8 @@ impl DaskTable { let mut qualified_name = Vec::from([String::from("root")]); match plan.original_plan { - LogicalPlan::TableScan(_table_scan) => { - let tbl = _table_scan - .source - .as_any() - .downcast_ref::() - .unwrap() - .table_name(); - qualified_name.push(tbl); + LogicalPlan::TableScan(table_scan) => { + qualified_name.push(table_scan.table_name); } _ => { println!("Nothing matches"); @@ -152,7 +164,7 @@ pub(crate) fn table_from_logical_plan(plan: &LogicalPlan) -> Option { LogicalPlan::Filter(filter) => table_from_logical_plan(&filter.input), LogicalPlan::TableScan(table_scan) => { // Get the TableProvider for this Table instance - let tbl_provider: Arc = table_scan.source.clone(); + let tbl_provider: Arc = table_scan.source.clone(); let tbl_schema: SchemaRef = tbl_provider.schema(); let fields = tbl_schema.fields(); From cbf5db02aed00bae50acb1012b663c3a46bf8772 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 17:08:12 -0400 Subject: [PATCH 11/61] Include RelDataType work --- dask_planner/src/lib.rs | 10 +- dask_planner/src/sql.rs | 20 +- dask_planner/src/sql/exceptions.rs | 3 + dask_planner/src/sql/table.rs | 40 ++-- dask_planner/src/sql/types.rs | 204 +++++++++++------- dask_planner/src/sql/types/rel_data_type.rs | 111 ++++++++++ .../src/sql/types/rel_data_type_field.rs | 70 ++++++ dask_sql/context.py | 13 +- dask_sql/physical/rel/base.py | 32 ++- tests/integration/test_select.py | 48 ++--- 10 files changed, 408 insertions(+), 143 deletions(-) create mode 100644 dask_planner/src/sql/exceptions.rs create mode 100644 dask_planner/src/sql/types/rel_data_type.rs create mode 100644 dask_planner/src/sql/types/rel_data_type_field.rs diff --git a/dask_planner/src/lib.rs b/dask_planner/src/lib.rs index df9343151..35189bace 100644 --- a/dask_planner/src/lib.rs +++ b/dask_planner/src/lib.rs @@ -13,11 +13,11 @@ static GLOBAL: MiMalloc = MiMalloc; /// dask_planner directory. 
#[pymodule] #[pyo3(name = "rust")] -fn rust(_py: Python, m: &PyModule) -> PyResult<()> { +fn rust(py: Python, m: &PyModule) -> PyResult<()> { // Register the python classes m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -25,5 +25,11 @@ fn rust(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + // Exceptions + m.add( + "DFParsingException", + py.get_type::(), + )?; + Ok(()) } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 90d8f8401..9056f36a2 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -1,4 +1,5 @@ pub mod column; +pub mod exceptions; pub mod function; pub mod logical; pub mod schema; @@ -6,6 +7,8 @@ pub mod statement; pub mod table; pub mod types; +use crate::sql::exceptions::ParsingException; + use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::catalog::TableReference; use datafusion::error::DataFusionError; @@ -57,14 +60,15 @@ impl ContextProvider for DaskSQLContext { // Build the Schema here let mut fields: Vec = Vec::new(); - // Iterate through the DaskTable instance and create a Schema instance - for (column_name, column_type) in &table.columns { - fields.push(Field::new( - column_name, - column_type.sql_type.clone(), - false, - )); - } + panic!("Uncomment this section .... before running"); + // // Iterate through the DaskTable instance and create a Schema instance + // for (column_name, column_type) in &table.columns { + // fields.push(Field::new( + // column_name, + // column_type.sql_type.clone(), + // false, + // )); + // } resp = Some(Schema::new(fields)); } diff --git a/dask_planner/src/sql/exceptions.rs b/dask_planner/src/sql/exceptions.rs new file mode 100644 index 000000000..e53aeb5b4 --- /dev/null +++ b/dask_planner/src/sql/exceptions.rs @@ -0,0 +1,3 @@ +use pyo3::create_exception; + +create_exception!(rust, ParsingException, pyo3::exceptions::PyException); diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 5ee4ec0e3..9e0873fce 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -1,5 +1,5 @@ use crate::sql::logical; -use crate::sql::types; +use crate::sql::types::rel_data_type::RelDataType; use async_trait::async_trait; @@ -93,7 +93,7 @@ pub struct DaskTable { pub(crate) name: String, #[allow(dead_code)] pub(crate) statistics: DaskStatistics, - pub(crate) columns: Vec<(String, types::DaskRelDataType)>, + pub(crate) columns: Vec<(String, RelDataType)>, } #[pymethods] @@ -108,12 +108,13 @@ impl DaskTable { } pub fn add_column(&mut self, column_name: String, column_type_str: String) { - let sql_type: types::DaskRelDataType = types::DaskRelDataType { - name: String::from(&column_name), - sql_type: types::sql_type_to_arrow_type(column_type_str), - }; + panic!("Need to uncomment and fix this before running!!"); + // let sql_type: RelDataType = RelDataType { + // name: String::from(&column_name), + // sql_type: types::sql_type_to_arrow_type(column_type_str), + // }; - self.columns.push((column_name, sql_type)); + // self.columns.push((column_name, sql_type)); } pub fn get_qualified_name(&self, plan: logical::PyLogicalPlan) -> Vec { @@ -140,8 +141,8 @@ impl DaskTable { cns } - pub fn column_types(&self) -> Vec { - let mut col_types: Vec = Vec::new(); + pub fn column_types(&self) -> Vec { + let mut col_types: Vec = Vec::new(); for col in &self.columns { col_types.push(col.1.clone()) } @@ -168,16 +169,17 @@ pub(crate) fn 
table_from_logical_plan(plan: &LogicalPlan) -> Option { let tbl_schema: SchemaRef = tbl_provider.schema(); let fields = tbl_schema.fields(); - let mut cols: Vec<(String, types::DaskRelDataType)> = Vec::new(); - for field in fields { - cols.push(( - String::from(field.name()), - types::DaskRelDataType { - name: String::from(field.name()), - sql_type: field.data_type().clone(), - }, - )); - } + let mut cols: Vec<(String, RelDataType)> = Vec::new(); + panic!("uncomment and fix this"); + // for field in fields { + // cols.push(( + // String::from(field.name()), + // RelDataType { + // name: String::from(field.name()), + // sql_type: field.data_type().clone(), + // }, + // )); + // } Some(DaskTable { name: String::from(&table_scan.table_name), diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index bbf872988..3be2537fb 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -1,79 +1,119 @@ use datafusion::arrow::datatypes::{DataType, TimeUnit}; +pub mod rel_data_type; +pub mod rel_data_type_field; + use pyo3::prelude::*; -#[pyclass] -#[derive(Debug, Clone)] -pub struct DaskRelDataType { - pub(crate) name: String, - pub(crate) sql_type: DataType, +/// Enumeration of the type names which can be used to construct a SQL type. Since +/// several SQL types do not exist as Rust types and also because the Enum +/// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used +/// in place of just using the built-in Rust types. +#[allow(non_camel_case_types)] +#[allow(clippy::upper_case_acronyms)] +enum SqlTypeName { + ANY, + ARRAY, + BIGINT, + BINARY, + BOOLEAN, + CHAR, + COLUMN_LIST, + CURSOR, + DATE, + DECIMAL, + DISTINCT, + DOUBLE, + DYNAMIC_STAR, + FLOAT, + GEOMETRY, + INTEGER, + INTERVAL_DAY, + INTERVAL_DAY_HOUR, + INTERVAL_DAY_MINUTE, + INTERVAL_DAY_SECOND, + INTERVAL_HOUR, + INTERVAL_HOUR_MINUTE, + INTERVAL_HOUR_SECOND, + INTERVAL_MINUTE, + INTERVAL_MINUTE_SECOND, + INTERVAL_MONTH, + INTERVAL_SECOND, + INTERVAL_YEAR, + INTERVAL_YEAR_MONTH, + MAP, + MULTISET, + NULL, + OTHER, + REAL, + ROW, + SARG, + SMALLINT, + STRUCTURED, + SYMBOL, + TIME, + TIME_WITH_LOCAL_TIME_ZONE, + TIMESTAMP, + TIMESTAMP_WITH_LOCAL_TIME_ZONE, + TINYINT, + UNKNOWN, + VARBINARY, + VARCHAR, } -#[pyclass(name = "DataType", module = "datafusion", subclass)] -#[derive(Debug, Clone)] -pub struct PyDataType { - pub data_type: DataType, -} +/// Takes an Arrow DataType (https://docs.rs/crate/arrow/latest/source/src/datatypes/datatype.rs) +/// and converts it to a SQL type. 
The SQL type is a String and represents the valid +/// SQL types which are supported by Dask-SQL +pub(crate) fn arrow_type_to_sql_type(arrow_type: DataType) -> String { + match arrow_type { + DataType::Null => String::from("NULL"), + DataType::Boolean => String::from("BOOLEAN"), + DataType::Int8 => String::from("TINYINT"), + DataType::UInt8 => String::from("TINYINT"), + DataType::Int16 => String::from("SMALLINT"), + DataType::UInt16 => String::from("SMALLINT"), + DataType::Int32 => String::from("INTEGER"), + DataType::UInt32 => String::from("INTEGER"), + DataType::Int64 => String::from("BIGINT"), + DataType::UInt64 => String::from("BIGINT"), + DataType::Float32 => String::from("FLOAT"), + DataType::Float64 => String::from("DOUBLE"), + DataType::Timestamp(unit, tz) => { + // let mut timestamp_str: String = "timestamp[".to_string(); -impl From for DataType { - fn from(data_type: PyDataType) -> DataType { - data_type.data_type - } -} + // let unit_str: &str = match unit { + // TimeUnit::Microsecond => "ps", + // TimeUnit::Millisecond => "ms", + // TimeUnit::Nanosecond => "ns", + // TimeUnit::Second => "s", + // }; -impl From for PyDataType { - fn from(data_type: DataType) -> PyDataType { - PyDataType { data_type } - } -} + // timestamp_str.push_str(&format!("{}", unit_str)); + // match tz { + // Some(e) => { + // timestamp_str.push_str(&format!(", {}", e)) + // }, + // None => (), + // } + // timestamp_str.push_str("]"); + // println!("timestamp_str: {:?}", timestamp_str); + // timestamp_str -#[pymethods] -impl DaskRelDataType { - #[new] - pub fn new(field_name: String, column_str_sql_type: String) -> Self { - DaskRelDataType { - name: field_name, - sql_type: sql_type_to_arrow_type(column_str_sql_type), + let mut timestamp_str: String = "TIMESTAMP".to_string(); + match tz { + Some(e) => { + timestamp_str.push_str("_WITH_LOCAL_TIME_ZONE(0)"); + } + None => timestamp_str.push_str("(0)"), + } + timestamp_str } - } - - pub fn get_column_name(&self) -> String { - self.name.clone() - } - - pub fn get_type(&self) -> PyDataType { - self.sql_type.clone().into() - } - - pub fn get_type_as_str(&self) -> String { - String::from(arrow_type_to_sql_type(self.sql_type.clone())) - } -} - -/// Takes an Arrow DataType (https://docs.rs/crate/arrow/latest/source/src/datatypes/datatype.rs) -/// and converts it to a SQL type. The SQL type is a String slice and represents the valid -/// SQL types which are supported by Dask-SQL -pub(crate) fn arrow_type_to_sql_type(arrow_type: DataType) -> &'static str { - match arrow_type { - DataType::Null => "NULL", - DataType::Boolean => "BOOLEAN", - DataType::Int8 => "TINYINT", - DataType::UInt8 => "TINYINT", - DataType::Int16 => "SMALLINT", - DataType::UInt16 => "SMALLINT", - DataType::Int32 => "INTEGER", - DataType::UInt32 => "INTEGER", - DataType::Int64 => "BIGINT", - DataType::UInt64 => "BIGINT", - DataType::Float32 => "FLOAT", - DataType::Float64 => "DOUBLE", - DataType::Timestamp { .. } => "TIMESTAMP", - DataType::Date32 => "DATE", - DataType::Date64 => "DATE", - DataType::Time32(..) => "TIMESTAMP", - DataType::Time64(..) => "TIMESTAMP", - DataType::Utf8 => "VARCHAR", - DataType::LargeUtf8 => "BIGVARCHAR", + DataType::Date32 => String::from("DATE"), + DataType::Date64 => String::from("DATE"), + DataType::Time32(..) => String::from("TIMESTAMP"), + DataType::Time64(..) 
=> String::from("TIMESTAMP"), + DataType::Utf8 => String::from("VARCHAR"), + DataType::LargeUtf8 => String::from("BIGVARCHAR"), _ => todo!("Unimplemented Arrow DataType encountered"), } } @@ -81,11 +121,11 @@ pub(crate) fn arrow_type_to_sql_type(arrow_type: DataType) -> &'static str { /// Takes a valid Dask-SQL type and converts that String representation to an instance /// of Arrow DataType (https://docs.rs/crate/arrow/latest/source/src/datatypes/datatype.rs) pub(crate) fn sql_type_to_arrow_type(str_sql_type: String) -> DataType { + println!("str_sql_type: {:?}", str_sql_type); + + // TODO: https://github.com/dask-contrib/dask-sql/issues/485 if str_sql_type.starts_with("timestamp") { - DataType::Timestamp( - TimeUnit::Millisecond, - Some(String::from("America/New_York")), - ) + DataType::Timestamp(TimeUnit::Nanosecond, Some(String::from("Europe/Berlin"))) } else { match &str_sql_type[..] { "NULL" => DataType::Null, @@ -97,11 +137,27 @@ pub(crate) fn sql_type_to_arrow_type(str_sql_type: String) -> DataType { "FLOAT" => DataType::Float32, "DOUBLE" => DataType::Float64, "VARCHAR" => DataType::Utf8, - "TIMESTAMP" => DataType::Timestamp( - TimeUnit::Millisecond, - Some(String::from("America/New_York")), - ), + "TIMESTAMP" => DataType::Timestamp(TimeUnit::Nanosecond, None), + "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => DataType::Timestamp(TimeUnit::Nanosecond, None), _ => todo!("Not yet implemented String value: {:?}", &str_sql_type), } } } + +#[pyclass(name = "DataType", module = "datafusion", subclass)] +#[derive(Debug, Clone)] +pub struct PyDataType { + pub data_type: DataType, +} + +impl From for DataType { + fn from(data_type: PyDataType) -> DataType { + data_type.data_type + } +} + +impl From for PyDataType { + fn from(data_type: DataType) -> PyDataType { + PyDataType { data_type } + } +} diff --git a/dask_planner/src/sql/types/rel_data_type.rs b/dask_planner/src/sql/types/rel_data_type.rs new file mode 100644 index 000000000..7641cc494 --- /dev/null +++ b/dask_planner/src/sql/types/rel_data_type.rs @@ -0,0 +1,111 @@ +use crate::sql::types::rel_data_type_field::RelDataTypeField; + +use std::collections::HashMap; + +use pyo3::prelude::*; + +const PRECISION_NOT_SPECIFIED: i32 = i32::MIN; +const SCALE_NOT_SPECIFIED: i32 = -1; + +/// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. +#[pyclass] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RelDataType { + nullable: bool, + field_list: Vec, +} + +/// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. +#[pymethods] +impl RelDataType { + /// Looks up a field by name. 
+ /// + /// # Arguments + /// + /// * `field_name` - A String containing the name of the field to find + /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise + #[pyo3(name = "getField")] + pub fn field(&self, field_name: String, case_sensitive: bool) -> RelDataTypeField { + assert!(!self.field_list.is_empty()); + let field_map: HashMap = self.field_map(); + if case_sensitive && field_map.len() > 0 { + field_map.get(&field_name).unwrap().clone() + } else { + for field in &self.field_list { + #[allow(clippy::if_same_then_else)] + if case_sensitive && field.name().eq(&field_name) { + return field.clone(); + } else if !case_sensitive && field.name().eq_ignore_ascii_case(&field_name) { + return field.clone(); + } + } + + // TODO: Throw a proper error here + panic!( + "Unable to find RelDataTypeField with name {:?} in the RelDataType field_list", + field_name + ); + } + } + + /// Returns a map from field names to fields. + /// + /// # Notes + /// + /// * If several fields have the same name, the map contains the first. + #[pyo3(name = "getFieldMap")] + pub fn field_map(&self) -> HashMap { + let mut fields: HashMap = HashMap::new(); + for field in &self.field_list { + fields.insert(String::from(field.name()), field.clone()); + } + fields + } + + /// Gets the fields in a struct type. The field count is equal to the size of the returned list. + #[pyo3(name = "getFieldList")] + pub fn field_list(&self) -> Vec { + assert!(!self.field_list.is_empty()); + self.field_list.clone() + } + + /// Returns the names of the fields in a struct type. The field count + /// is equal to the size of the returned list. + #[pyo3(name = "getFieldNames")] + pub fn field_names(&self) -> Vec { + assert!(!self.field_list.is_empty()); + let mut field_names: Vec = Vec::new(); + for field in &self.field_list { + field_names.push(String::from(field.name())); + } + field_names + } + + /// Returns the number of fields in a struct type. + #[pyo3(name = "getFieldCount")] + pub fn field_count(&self) -> usize { + assert!(!self.field_list.is_empty()); + self.field_list.len() + } + + #[pyo3(name = "isStruct")] + pub fn is_struct(&self) -> bool { + self.field_list.len() > 0 + } + + /// Queries whether this type allows null values. + #[pyo3(name = "isNullable")] + pub fn is_nullable(&self) -> bool { + self.nullable + } + + #[pyo3(name = "getPrecision")] + pub fn precision(&self) -> i32 { + PRECISION_NOT_SPECIFIED + } + + #[pyo3(name = "getScale")] + pub fn scale(&self) -> i32 { + SCALE_NOT_SPECIFIED + } +} diff --git a/dask_planner/src/sql/types/rel_data_type_field.rs b/dask_planner/src/sql/types/rel_data_type_field.rs new file mode 100644 index 000000000..862e09cf3 --- /dev/null +++ b/dask_planner/src/sql/types/rel_data_type_field.rs @@ -0,0 +1,70 @@ +use crate::sql::types; +use crate::sql::types::rel_data_type::RelDataType; + +use std::fmt; + +use pyo3::prelude::*; + +use super::SqlTypeName; + +/// RelDataTypeField represents the definition of a field in a structured RelDataType. 
+#[pyclass] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RelDataTypeField { + name: String, + data_type: RelDataType, + index: u8, +} + +#[pymethods] +impl RelDataTypeField { + #[pyo3(name = "getName")] + pub fn name(&self) -> &str { + &self.name + } + + #[pyo3(name = "getIndex")] + pub fn index(&self) -> u8 { + self.index + } + + #[pyo3(name = "getType")] + pub fn data_type(&self) -> RelDataType { + self.data_type.clone() + } + + /// Since this logic is being ported from Java getKey is synonymous with getName. + /// Alas it is used in certain places so it is implemented here to allow other + /// places in the code base to not have to change. + #[pyo3(name = "getKey")] + pub fn get_key(&self) -> &str { + self.name() + } + + /// Since this logic is being ported from Java getValue is synonymous with getType. + /// Alas it is used in certain places so it is implemented here to allow other + /// places in the code base to not have to change. + #[pyo3(name = "getValue")] + pub fn get_value(&self) -> RelDataType { + self.data_type() + } + + // TODO: Uncomment after implementing in RelDataType + // #[pyo3(name = "isDynamicStar")] + // pub fn is_dynamic_star(&self) -> bool { + // self.data_type.getSqlTypeName() == SqlTypeName.DYNAMIC_STAR + // } +} + +impl fmt::Display for RelDataTypeField { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("Field: ")?; + fmt.write_str(&self.name)?; + fmt.write_str(" - Index: ")?; + fmt.write_str(&self.index.to_string())?; + // TODO: Uncomment this after implementing the Display trait in RelDataType + // fmt.write_str(" - DataType: ")?; + // fmt.write_str(self.data_type.to_string())?; + Ok(()) + } +} diff --git a/dask_sql/context.py b/dask_sql/context.py index 89912d968..9e08091f8 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -10,7 +10,7 @@ from dask.base import optimize from dask.distributed import Client -from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable +from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, DFParsingException try: import dask_cuda # noqa: F401 @@ -30,6 +30,7 @@ from dask_sql.mappings import python_to_sql_type from dask_sql.physical.rel import RelConverter, custom, logical from dask_sql.physical.rex import RexConverter, core +from dask_sql.utils import ParsingException if TYPE_CHECKING: from dask_planner.rust import Expression @@ -741,9 +742,7 @@ def _prepare_schemas(self): table = DaskTable(name, row_count) df = dc.df - logger.debug( - f"Adding table '{name}' to schema with columns: {list(df.columns)}" - ) + for column in df.columns: data_type = df[column].dtype sql_data_type = python_to_sql_type(data_type) @@ -808,7 +807,11 @@ def _get_ral(self, sql): f"Multiple 'Statements' encountered for SQL {sql}. Please share this with the dev team!" ) - nonOptimizedRel = self.context.logical_relational_algebra(sqlTree[0]) + try: + nonOptimizedRel = self.context.logical_relational_algebra(sqlTree[0]) + except DFParsingException as pe: + raise ParsingException(sql, str(pe)) from None + rel = nonOptimizedRel logger.debug(f"_get_ral -> nonOptimizedRelNode: {nonOptimizedRel}") # # Optimization might remove some alias projects. Make sure to keep them here. 
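For illustration only, a minimal, self-contained sketch of the error-translation pattern added to `_get_ral` above; the two exception classes are stand-ins that mimic `DFParsingException` (raised by the Rust planner) and `dask_sql.utils.ParsingException`, not the real implementations:

    # Hedged sketch: stand-in exception classes, not the real bindings.
    class DFParsingException(Exception):
        """Mimics the parsing error surfaced by the Rust planner."""

    class ParsingException(Exception):
        """Mimics dask_sql.utils.ParsingException."""
        def __init__(self, sql, message):
            super().__init__(f"Could not parse {sql!r}: {message}")

    def logical_relational_algebra(sql):
        # Pretend the planner rejects the statement.
        raise DFParsingException("Expected FROM, found: FRM")

    sql = "SELECT * FRM df"
    try:
        try:
            rel = logical_relational_algebra(sql)
        except DFParsingException as pe:
            # `from None` drops the Rust-side traceback, as in the diff above.
            raise ParsingException(sql, str(pe)) from None
    except ParsingException as exc:
        print(exc)

Re-raising at this boundary keeps the Python-facing error type stable even though the parsing now happens in Rust.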
diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 6601a98ac..997940686 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: import dask_sql - from dask_planner.rust import DaskRelDataType, DaskTable, LogicalPlan + from dask_planner.rust import LogicalPlan, RelDataType logger = logging.getLogger(__name__) @@ -29,24 +29,26 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFra raise NotImplementedError @staticmethod - def fix_column_to_row_type(cc: ColumnContainer, column_names) -> ColumnContainer: + def fix_column_to_row_type( + cc: ColumnContainer, row_type: "RelDataType" + ) -> ColumnContainer: """ Make sure that the given column container has the column names specified by the row type. We assume that the column order is already correct and will just "blindly" rename the columns. """ - # field_names = [str(x) for x in row_type.getFieldNames()] + field_names = [str(x) for x in row_type.getFieldNames()] - logger.debug(f"Renaming {cc.columns} to {column_names}") + logger.debug(f"Renaming {cc.columns} to {field_names}") - cc = cc.rename(columns=dict(zip(cc.columns, column_names))) + cc = cc.rename(columns=dict(zip(cc.columns, field_names))) # TODO: We can also check for the types here and do any conversions if needed - return cc.limit_to(column_names) + return cc.limit_to(field_names) @staticmethod - def check_columns_from_row_type(df: dd.DataFrame, row_type: "DaskRelDataType"): + def check_columns_from_row_type(df: dd.DataFrame, row_type: "RelDataType"): """ Similar to `self.fix_column_to_row_type`, but this time check for the correct column names instead of @@ -81,7 +83,7 @@ def assert_inputs( return [RelConverter.convert(input_rel, context) for input_rel in input_rels] @staticmethod - def fix_dtype_to_row_type(dc: DataContainer, dask_table: "DaskTable"): + def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): """ Fix the dtype of the given data container (or: the df within it) to the data type given as argument. 
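A hedged sketch of the `row_type` interface these helpers now consume (`getFieldNames()` above, `getFieldList()` in the hunk below); `RowTypeStub` and `FieldStub` are hypothetical stand-ins for the Rust-backed `RelDataType` and `RelDataTypeField`:

    # Hedged sketch: stand-ins for the PyO3-backed RelDataType classes.
    class FieldStub:
        def __init__(self, index, name, sql_type):
            self._index, self._name, self._type = index, name, sql_type

        def getIndex(self):
            return self._index

        def getName(self):
            return self._name

        def getType(self):
            return self._type

    class RowTypeStub:
        def __init__(self, fields):
            self._fields = fields

        def getFieldNames(self):
            return [f.getName() for f in self._fields]

        def getFieldList(self):
            return list(self._fields)

    row_type = RowTypeStub([FieldStub(0, "a", "BIGINT"), FieldStub(1, "b", "VARCHAR")])
    print(row_type.getFieldNames())  # ['a', 'b']
    print({f.getIndex(): f.getType() for f in row_type.getFieldList()})
    # {0: 'BIGINT', 1: 'VARCHAR'}

The real `ColumnContainer` then renames its columns via `dict(zip(cc.columns, field_names))` and limits itself to `field_names`, exactly as in the hunk above.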
@@ -93,9 +95,17 @@ def fix_dtype_to_row_type(dc: DataContainer, dask_table: "DaskTable"): TODO: we should check the nullability of the SQL type """ df = dc.df + cc = dc.column_container + + field_types = { + int(field.getIndex()): str(field.getType()) + for field in row_type.getFieldList() + } + + for index, field_type in field_types.items(): + expected_type = sql_to_python_type(field_type) + field_name = cc.get_backend_by_frontend_index(index) - for col in dask_table.column_types(): - expected_type = sql_to_python_type(col.get_type_as_str()) - df = cast_column_type(df, col.get_column_name(), expected_type) + df = cast_column_type(df, field_name, expected_type) return DataContainer(df, dc.column_container) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 8f2f08218..afc9220d9 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -23,30 +23,30 @@ def test_select_alias(c, df): assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) -def test_select_column(c, df): - result_df = c.sql("SELECT a FROM df") - - assert_eq(result_df, df[["a"]]) - - -def test_select_different_types(c): - expected_df = pd.DataFrame( - { - "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), - "string": ["this is a test", "another test", "äölüć", ""], - "integer": [1, 2, -4, 5], - "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], - } - ) - c.create_table("df", expected_df) - result_df = c.sql( - """ - SELECT * - FROM df - """ - ) - - assert_eq(result_df, expected_df) +# def test_select_column(c, df): +# result_df = c.sql("SELECT a FROM df") + +# assert_eq(result_df, df[["a"]]) + + +# def test_select_different_types(c): +# expected_df = pd.DataFrame( +# { +# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), +# "string": ["this is a test", "another test", "äölüć", ""], +# "integer": [1, 2, -4, 5], +# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], +# } +# ) +# c.create_table("df", expected_df) +# result_df = c.sql( +# """ +# SELECT * +# FROM df +# """ +# ) + +# assert_eq(result_df, expected_df) @pytest.mark.skip(reason="WIP DataFusion") From d9380a6c460e664ed987f9ef2bde5c57015fa690 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 17:09:33 -0400 Subject: [PATCH 12/61] Include RelDataType work --- dask_planner/src/sql/types/rel_data_type.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dask_planner/src/sql/types/rel_data_type.rs b/dask_planner/src/sql/types/rel_data_type.rs index 7641cc494..3ab59df71 100644 --- a/dask_planner/src/sql/types/rel_data_type.rs +++ b/dask_planner/src/sql/types/rel_data_type.rs @@ -32,10 +32,9 @@ impl RelDataType { field_map.get(&field_name).unwrap().clone() } else { for field in &self.field_list { - #[allow(clippy::if_same_then_else)] - if case_sensitive && field.name().eq(&field_name) { - return field.clone(); - } else if !case_sensitive && field.name().eq_ignore_ascii_case(&field_name) { + if (case_sensitive && field.name().eq(&field_name)) + || (!case_sensitive && field.name().eq_ignore_ascii_case(&field_name)) + { return field.clone(); } } From ad56fc204b73c23d310e76dea41f12e0de417010 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 21 Apr 2022 21:28:09 -0400 Subject: [PATCH 13/61] Introduced SqlTypeName Enum in Rust and mappings for Python --- dask_planner/src/expression.rs | 23 --- dask_planner/src/lib.rs | 1 + dask_planner/src/sql.rs | 14 +- dask_planner/src/sql/logical.rs | 1 + dask_planner/src/sql/table.rs | 72 +++---- 
dask_planner/src/sql/types.rs | 175 ++++++++---------- dask_planner/src/sql/types/rel_data_type.rs | 8 + .../src/sql/types/rel_data_type_field.rs | 24 ++- dask_sql/datacontainer.py | 1 + dask_sql/mappings.py | 58 +++--- dask_sql/physical/rel/base.py | 9 +- dask_sql/physical/rel/logical/project.py | 43 ++++- dask_sql/physical/rel/logical/table_scan.py | 20 +- 13 files changed, 226 insertions(+), 223 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 31043f118..c0fc34068 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -1,7 +1,5 @@ use crate::sql::logical; -use crate::sql::types::PyDataType; -use pyo3::PyMappingProtocol; use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; use std::convert::{From, Into}; @@ -389,16 +387,6 @@ impl PyExpr { self.expr.clone().is_null().into() } - pub fn cast(&self, to: PyDataType) -> PyExpr { - // self.expr.cast_to() requires DFSchema to validate that the cast - // is supported, omit that for now - let expr = Expr::Cast { - expr: Box::new(self.expr.clone()), - data_type: to.data_type, - }; - expr.into() - } - /// TODO: I can't express how much I dislike explicity listing all of these methods out /// but PyO3 makes it necessary since its annotations cannot be used in trait impl blocks #[pyo3(name = "getFloat32Value")] @@ -557,14 +545,3 @@ impl PyExpr { } } } - -#[pyproto] -impl PyMappingProtocol for PyExpr { - fn __getitem__(&self, key: &str) -> PyResult { - Ok(Expr::GetIndexedField { - expr: Box::new(self.expr.clone()), - key: ScalarValue::Utf8(Some(key.to_string())), - } - .into()) - } -} diff --git a/dask_planner/src/lib.rs b/dask_planner/src/lib.rs index 35189bace..550175c93 100644 --- a/dask_planner/src/lib.rs +++ b/dask_planner/src/lib.rs @@ -17,6 +17,7 @@ fn rust(py: Python, m: &PyModule) -> PyResult<()> { // Register the python classes m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 9056f36a2..88fa5b901 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -59,16 +59,10 @@ impl ContextProvider for DaskSQLContext { if table.name.eq(&name.table()) { // Build the Schema here let mut fields: Vec = Vec::new(); - - panic!("Uncomment this section .... 
before running"); - // // Iterate through the DaskTable instance and create a Schema instance - // for (column_name, column_type) in &table.columns { - // fields.push(Field::new( - // column_name, - // column_type.sql_type.clone(), - // false, - // )); - // } + // Iterate through the DaskTable instance and create a Schema instance + for (column_name, column_type) in &table.columns { + fields.push(Field::new(column_name, column_type.to_arrow(), false)); + } resp = Some(Schema::new(fields)); } diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 283aab317..57ebacfeb 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -86,6 +86,7 @@ impl PyLogicalPlan { /// If the LogicalPlan represents access to a Table that instance is returned /// otherwise None is returned + #[pyo3(name = "getTable")] pub fn table(&mut self) -> PyResult { match table::table_from_logical_plan(&self.current_node()) { Some(table) => Ok(table), diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 9e0873fce..4bf0ff292 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -1,9 +1,11 @@ use crate::sql::logical; use crate::sql::types::rel_data_type::RelDataType; +use crate::sql::types::rel_data_type_field::RelDataTypeField; +use crate::sql::types::SqlTypeName; use async_trait::async_trait; -use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; pub use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; use datafusion::physical_plan::{empty::EmptyExec, project_schema, ExecutionPlan}; @@ -93,7 +95,7 @@ pub struct DaskTable { pub(crate) name: String, #[allow(dead_code)] pub(crate) statistics: DaskStatistics, - pub(crate) columns: Vec<(String, RelDataType)>, + pub(crate) columns: Vec<(String, SqlTypeName)>, } #[pymethods] @@ -107,17 +109,15 @@ impl DaskTable { } } - pub fn add_column(&mut self, column_name: String, column_type_str: String) { - panic!("Need to uncomment and fix this before running!!"); - // let sql_type: RelDataType = RelDataType { - // name: String::from(&column_name), - // sql_type: types::sql_type_to_arrow_type(column_type_str), - // }; - - // self.columns.push((column_name, sql_type)); + // TODO: Really wish we could accept a SqlTypeName instance here instead of a String for `column_type` .... 
+ #[pyo3(name = "add_column")] + pub fn add_column(&mut self, column_name: String, column_type: String) { + self.columns + .push((column_name, SqlTypeName::from_string(&column_type))); } - pub fn get_qualified_name(&self, plan: logical::PyLogicalPlan) -> Vec { + #[pyo3(name = "getQualifiedName")] + pub fn qualified_name(&self, plan: logical::PyLogicalPlan) -> Vec { let mut qualified_name = Vec::from([String::from("root")]); match plan.original_plan { @@ -133,28 +133,13 @@ impl DaskTable { qualified_name } - pub fn column_names(&self) -> Vec { - let mut cns: Vec = Vec::new(); - for c in &self.columns { - cns.push(String::from(&c.0)); - } - cns - } - - pub fn column_types(&self) -> Vec { - let mut col_types: Vec = Vec::new(); - for col in &self.columns { - col_types.push(col.1.clone()) + #[pyo3(name = "getRowType")] + pub fn row_type(&self) -> RelDataType { + let mut fields: Vec = Vec::new(); + for (name, data_type) in &self.columns { + fields.push(RelDataTypeField::new(name.clone(), data_type.clone(), 255)); } - col_types - } - - pub fn num_columns(&self) { - println!( - "There are {} columns in table {}", - self.columns.len(), - self.name - ); + RelDataType::new(false, fields) } } @@ -167,19 +152,16 @@ pub(crate) fn table_from_logical_plan(plan: &LogicalPlan) -> Option { // Get the TableProvider for this Table instance let tbl_provider: Arc = table_scan.source.clone(); let tbl_schema: SchemaRef = tbl_provider.schema(); - let fields = tbl_schema.fields(); - - let mut cols: Vec<(String, RelDataType)> = Vec::new(); - panic!("uncomment and fix this"); - // for field in fields { - // cols.push(( - // String::from(field.name()), - // RelDataType { - // name: String::from(field.name()), - // sql_type: field.data_type().clone(), - // }, - // )); - // } + let fields: &Vec = tbl_schema.fields(); + + let mut cols: Vec<(String, SqlTypeName)> = Vec::new(); + for field in fields { + let data_type: &DataType = field.data_type(); + cols.push(( + String::from(field.name()), + SqlTypeName::from_arrow(data_type), + )); + } Some(DaskTable { name: String::from(&table_scan.table_name), diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 3be2537fb..62f26e9c4 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -1,4 +1,4 @@ -use datafusion::arrow::datatypes::{DataType, TimeUnit}; +use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; pub mod rel_data_type; pub mod rel_data_type_field; @@ -11,7 +11,9 @@ use pyo3::prelude::*; /// in place of just using the built-in Rust types. #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] -enum SqlTypeName { +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "SqlTypeName", module = "datafusion")] +pub enum SqlTypeName { ANY, ARRAY, BIGINT, @@ -61,103 +63,88 @@ enum SqlTypeName { VARCHAR, } -/// Takes an Arrow DataType (https://docs.rs/crate/arrow/latest/source/src/datatypes/datatype.rs) -/// and converts it to a SQL type. 
The SQL type is a String and represents the valid -/// SQL types which are supported by Dask-SQL -pub(crate) fn arrow_type_to_sql_type(arrow_type: DataType) -> String { - match arrow_type { - DataType::Null => String::from("NULL"), - DataType::Boolean => String::from("BOOLEAN"), - DataType::Int8 => String::from("TINYINT"), - DataType::UInt8 => String::from("TINYINT"), - DataType::Int16 => String::from("SMALLINT"), - DataType::UInt16 => String::from("SMALLINT"), - DataType::Int32 => String::from("INTEGER"), - DataType::UInt32 => String::from("INTEGER"), - DataType::Int64 => String::from("BIGINT"), - DataType::UInt64 => String::from("BIGINT"), - DataType::Float32 => String::from("FLOAT"), - DataType::Float64 => String::from("DOUBLE"), - DataType::Timestamp(unit, tz) => { - // let mut timestamp_str: String = "timestamp[".to_string(); - - // let unit_str: &str = match unit { - // TimeUnit::Microsecond => "ps", - // TimeUnit::Millisecond => "ms", - // TimeUnit::Nanosecond => "ns", - // TimeUnit::Second => "s", - // }; - - // timestamp_str.push_str(&format!("{}", unit_str)); - // match tz { - // Some(e) => { - // timestamp_str.push_str(&format!(", {}", e)) - // }, - // None => (), - // } - // timestamp_str.push_str("]"); - // println!("timestamp_str: {:?}", timestamp_str); - // timestamp_str - - let mut timestamp_str: String = "TIMESTAMP".to_string(); - match tz { - Some(e) => { - timestamp_str.push_str("_WITH_LOCAL_TIME_ZONE(0)"); - } - None => timestamp_str.push_str("(0)"), - } - timestamp_str +impl SqlTypeName { + pub fn to_arrow(&self) -> DataType { + match self { + SqlTypeName::NULL => DataType::Null, + SqlTypeName::BOOLEAN => DataType::Boolean, + SqlTypeName::TINYINT => DataType::Int8, + SqlTypeName::SMALLINT => DataType::Int16, + SqlTypeName::INTEGER => DataType::Int32, + SqlTypeName::BIGINT => DataType::Int64, + SqlTypeName::REAL => DataType::Float16, + SqlTypeName::FLOAT => DataType::Float32, + SqlTypeName::DOUBLE => DataType::Float64, + SqlTypeName::DATE => DataType::Date64, + SqlTypeName::TIMESTAMP => DataType::Timestamp(TimeUnit::Nanosecond, None), + _ => todo!(), } - DataType::Date32 => String::from("DATE"), - DataType::Date64 => String::from("DATE"), - DataType::Time32(..) => String::from("TIMESTAMP"), - DataType::Time64(..) => String::from("TIMESTAMP"), - DataType::Utf8 => String::from("VARCHAR"), - DataType::LargeUtf8 => String::from("BIGVARCHAR"), - _ => todo!("Unimplemented Arrow DataType encountered"), } -} - -/// Takes a valid Dask-SQL type and converts that String representation to an instance -/// of Arrow DataType (https://docs.rs/crate/arrow/latest/source/src/datatypes/datatype.rs) -pub(crate) fn sql_type_to_arrow_type(str_sql_type: String) -> DataType { - println!("str_sql_type: {:?}", str_sql_type); - // TODO: https://github.com/dask-contrib/dask-sql/issues/485 - if str_sql_type.starts_with("timestamp") { - DataType::Timestamp(TimeUnit::Nanosecond, Some(String::from("Europe/Berlin"))) - } else { - match &str_sql_type[..] 
{ - "NULL" => DataType::Null, - "BOOLEAN" => DataType::Boolean, - "TINYINT" => DataType::Int8, - "SMALLINT" => DataType::Int16, - "INTEGER" => DataType::Int32, - "BIGINT" => DataType::Int64, - "FLOAT" => DataType::Float32, - "DOUBLE" => DataType::Float64, - "VARCHAR" => DataType::Utf8, - "TIMESTAMP" => DataType::Timestamp(TimeUnit::Nanosecond, None), - "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => DataType::Timestamp(TimeUnit::Nanosecond, None), - _ => todo!("Not yet implemented String value: {:?}", &str_sql_type), + pub fn from_arrow(data_type: &DataType) -> Self { + match data_type { + DataType::Null => SqlTypeName::NULL, + DataType::Boolean => SqlTypeName::BOOLEAN, + DataType::Int8 => SqlTypeName::TINYINT, + DataType::Int16 => SqlTypeName::SMALLINT, + DataType::Int32 => SqlTypeName::INTEGER, + DataType::Int64 => SqlTypeName::BIGINT, + DataType::UInt8 => SqlTypeName::TINYINT, + DataType::UInt16 => SqlTypeName::SMALLINT, + DataType::UInt32 => SqlTypeName::INTEGER, + DataType::UInt64 => SqlTypeName::BIGINT, + DataType::Float16 => SqlTypeName::REAL, + DataType::Float32 => SqlTypeName::FLOAT, + DataType::Float64 => SqlTypeName::DOUBLE, + DataType::Timestamp(_unit, tz) => match tz { + Some(..) => SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE, + None => SqlTypeName::TIMESTAMP, + }, + DataType::Date32 => SqlTypeName::DATE, + DataType::Date64 => SqlTypeName::DATE, + DataType::Interval(unit) => match unit { + IntervalUnit::DayTime => SqlTypeName::INTERVAL_DAY, + IntervalUnit::YearMonth => SqlTypeName::INTERVAL_YEAR_MONTH, + IntervalUnit::MonthDayNano => SqlTypeName::INTERVAL_MONTH, + }, + DataType::Binary => SqlTypeName::BINARY, + DataType::FixedSizeBinary(_size) => SqlTypeName::VARBINARY, + DataType::Utf8 => SqlTypeName::CHAR, + DataType::LargeUtf8 => SqlTypeName::VARCHAR, + DataType::Struct(_fields) => SqlTypeName::STRUCTURED, + DataType::Decimal(_precision, _scale) => SqlTypeName::DECIMAL, + DataType::Map(_field, _bool) => SqlTypeName::MAP, + _ => todo!(), } } -} - -#[pyclass(name = "DataType", module = "datafusion", subclass)] -#[derive(Debug, Clone)] -pub struct PyDataType { - pub data_type: DataType, -} - -impl From for DataType { - fn from(data_type: PyDataType) -> DataType { - data_type.data_type - } -} -impl From for PyDataType { - fn from(data_type: DataType) -> PyDataType { - PyDataType { data_type } + pub fn from_string(input_type: &str) -> Self { + match input_type { + "SqlTypeName.NULL" => SqlTypeName::NULL, + "SqlTypeName.BOOLEAN" => SqlTypeName::BOOLEAN, + "SqlTypeName.TINYINT" => SqlTypeName::TINYINT, + "SqlTypeName.SMALLINT" => SqlTypeName::SMALLINT, + "SqlTypeName.INTEGER" => SqlTypeName::INTEGER, + "SqlTypeName.BIGINT" => SqlTypeName::BIGINT, + "SqlTypeName.REAL" => SqlTypeName::REAL, + "SqlTypeName.FLOAT" => SqlTypeName::FLOAT, + "SqlTypeName.DOUBLE" => SqlTypeName::DOUBLE, + "SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE" => { + SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE + } + "SqlTypeName.TIMESTAMP" => SqlTypeName::TIMESTAMP, + "SqlTypeName.DATE" => SqlTypeName::DATE, + "SqlTypeName.INTERVAL_DAY" => SqlTypeName::INTERVAL_DAY, + "SqlTypeName.INTERVAL_YEAR_MONTH" => SqlTypeName::INTERVAL_YEAR_MONTH, + "SqlTypeName.INTERVAL_MONTH" => SqlTypeName::INTERVAL_MONTH, + "SqlTypeName.BINARY" => SqlTypeName::BINARY, + "SqlTypeName.VARBINARY" => SqlTypeName::VARBINARY, + "SqlTypeName.CHAR" => SqlTypeName::CHAR, + "SqlTypeName.VARCHAR" => SqlTypeName::VARCHAR, + "SqlTypeName.STRUCTURED" => SqlTypeName::STRUCTURED, + "SqlTypeName.DECIMAL" => SqlTypeName::DECIMAL, + "SqlTypeName.MAP" => 
SqlTypeName::MAP, + _ => todo!(), + } } } diff --git a/dask_planner/src/sql/types/rel_data_type.rs b/dask_planner/src/sql/types/rel_data_type.rs index 3ab59df71..c0e8b594a 100644 --- a/dask_planner/src/sql/types/rel_data_type.rs +++ b/dask_planner/src/sql/types/rel_data_type.rs @@ -18,6 +18,14 @@ pub struct RelDataType { /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. #[pymethods] impl RelDataType { + #[new] + pub fn new(nullable: bool, fields: Vec) -> Self { + Self { + nullable: nullable, + field_list: fields, + } + } + /// Looks up a field by name. /// /// # Arguments diff --git a/dask_planner/src/sql/types/rel_data_type_field.rs b/dask_planner/src/sql/types/rel_data_type_field.rs index 862e09cf3..1a01e32d0 100644 --- a/dask_planner/src/sql/types/rel_data_type_field.rs +++ b/dask_planner/src/sql/types/rel_data_type_field.rs @@ -1,23 +1,30 @@ -use crate::sql::types; use crate::sql::types::rel_data_type::RelDataType; +use crate::sql::types::SqlTypeName; use std::fmt; use pyo3::prelude::*; -use super::SqlTypeName; - /// RelDataTypeField represents the definition of a field in a structured RelDataType. #[pyclass] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RelDataTypeField { name: String, - data_type: RelDataType, + data_type: SqlTypeName, index: u8, } #[pymethods] impl RelDataTypeField { + #[new] + pub fn new(name: String, data_type: SqlTypeName, index: u8) -> Self { + Self { + name: name, + data_type: data_type, + index: index, + } + } + #[pyo3(name = "getName")] pub fn name(&self) -> &str { &self.name @@ -29,7 +36,7 @@ impl RelDataTypeField { } #[pyo3(name = "getType")] - pub fn data_type(&self) -> RelDataType { + pub fn data_type(&self) -> SqlTypeName { self.data_type.clone() } @@ -45,10 +52,15 @@ impl RelDataTypeField { /// Alas it is used in certain places so it is implemented here to allow other /// places in the code base to not have to change. #[pyo3(name = "getValue")] - pub fn get_value(&self) -> RelDataType { + pub fn get_value(&self) -> SqlTypeName { self.data_type() } + #[pyo3(name = "setValue")] + pub fn set_value(&mut self, data_type: SqlTypeName) { + self.data_type = data_type + } + // TODO: Uncomment after implementing in RelDataType // #[pyo3(name = "isDynamicStar")] // pub fn is_dynamic_star(&self) -> bool { diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index db77c9dfc..e92f5b6e3 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -121,6 +121,7 @@ def get_backend_by_frontend_index(self, index: int) -> str: Get back the dask column, which is referenced by the frontend (SQL) column with the given index. 
""" + print(f"self._frontend_columns: {self._frontend_columns} index: {index}") frontend_column = self._frontend_columns[index] backend_column = self._frontend_backend_mapping[frontend_column] return backend_column diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index edd919085..2d6a04069 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -7,35 +7,36 @@ import numpy as np import pandas as pd +from dask_planner.rust import SqlTypeName from dask_sql._compat import FLOAT_NAN_IMPLEMENTED logger = logging.getLogger(__name__) - +# Default mapping between python types and SQL types _PYTHON_TO_SQL = { - np.float64: "DOUBLE", - np.float32: "FLOAT", - np.int64: "BIGINT", - pd.Int64Dtype(): "BIGINT", - np.int32: "INTEGER", - pd.Int32Dtype(): "INTEGER", - np.int16: "SMALLINT", - pd.Int16Dtype(): "SMALLINT", - np.int8: "TINYINT", - pd.Int8Dtype(): "TINYINT", - np.uint64: "BIGINT", - pd.UInt64Dtype(): "BIGINT", - np.uint32: "INTEGER", - pd.UInt32Dtype(): "INTEGER", - np.uint16: "SMALLINT", - pd.UInt16Dtype(): "SMALLINT", - np.uint8: "TINYINT", - pd.UInt8Dtype(): "TINYINT", - np.bool8: "BOOLEAN", - pd.BooleanDtype(): "BOOLEAN", - np.object_: "VARCHAR", - pd.StringDtype(): "VARCHAR", - np.datetime64: "TIMESTAMP", + np.float64: SqlTypeName.DOUBLE, + np.float32: SqlTypeName.FLOAT, + np.int64: SqlTypeName.BIGINT, + pd.Int64Dtype(): SqlTypeName.BIGINT, + np.int32: SqlTypeName.INTEGER, + pd.Int32Dtype(): SqlTypeName.INTEGER, + np.int16: SqlTypeName.SMALLINT, + pd.Int16Dtype(): SqlTypeName.SMALLINT, + np.int8: SqlTypeName.TINYINT, + pd.Int8Dtype(): SqlTypeName.TINYINT, + np.uint64: SqlTypeName.BIGINT, + pd.UInt64Dtype(): SqlTypeName.BIGINT, + np.uint32: SqlTypeName.INTEGER, + pd.UInt32Dtype(): SqlTypeName.INTEGER, + np.uint16: SqlTypeName.SMALLINT, + pd.UInt16Dtype(): SqlTypeName.SMALLINT, + np.uint8: SqlTypeName.TINYINT, + pd.UInt8Dtype(): SqlTypeName.TINYINT, + np.bool8: SqlTypeName.BOOLEAN, + pd.BooleanDtype(): SqlTypeName.BOOLEAN, + np.object_: SqlTypeName.VARCHAR, + pd.StringDtype(): SqlTypeName.VARCHAR, + np.datetime64: SqlTypeName.TIMESTAMP, } if FLOAT_NAN_IMPLEMENTED: # pragma: no cover @@ -81,13 +82,13 @@ } -def python_to_sql_type(python_type): +def python_to_sql_type(python_type) -> "SqlTypeName": """Mapping between python and SQL types.""" if isinstance(python_type, np.dtype): python_type = python_type.type if pd.api.types.is_datetime64tz_dtype(python_type): - return "TIMESTAMP_WITH_LOCAL_TIME_ZONE" + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE try: return _PYTHON_TO_SQL[python_type] @@ -193,7 +194,10 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: def sql_to_python_type(sql_type: str) -> type: """Turn an SQL type into a dataframe dtype""" - logger.debug(f"mappings.sql_to_python_type() -> sql_type: {sql_type}") + # Ex: Rust SqlTypeName Enum str value 'SqlTypeName.DOUBLE' + if sql_type.find(".") != -1: + sql_type = sql_type.split(".")[1] + if sql_type.startswith("CHAR(") or sql_type.startswith("VARCHAR("): return pd.StringDtype() elif sql_type.startswith("INTERVAL"): diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 997940686..d49797f21 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -38,6 +38,7 @@ def fix_column_to_row_type( We assume that the column order is already correct and will just "blindly" rename the columns. 
""" + print(f"type(row_type): {type(row_type)}") field_names = [str(x) for x in row_type.getFieldNames()] logger.debug(f"Renaming {cc.columns} to {field_names}") @@ -98,14 +99,14 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): cc = dc.column_container field_types = { - int(field.getIndex()): str(field.getType()) + str(field.getName()): str(field.getType()) for field in row_type.getFieldList() } - for index, field_type in field_types.items(): + for sql_field_name, field_type in field_types.items(): expected_type = sql_to_python_type(field_type) - field_name = cc.get_backend_by_frontend_index(index) + df_field_name = cc.get_backend_by_frontend_name(sql_field_name) - df = cast_column_type(df, field_name, expected_type) + df = cast_column_type(df, df_field_name, expected_type) return DataContainer(df, dc.column_container) diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 4cbf87e6b..ddac9fe66 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -29,8 +29,12 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container + # Collect all (new) columns + # named_projects = rel.getNamedProjects() + column_names = [] - new_columns, new_mappings = {}, {} + new_columns = {} + new_mappings = {} projection = rel.projection() @@ -82,3 +86,40 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai dc = DataContainer(df, cc) dc = self.fix_dtype_to_row_type(dc, rel.table()) return dc + + # for expr, key in named_projects: + # key = str(key) + # column_names.append(key) + + # # shortcut: if we have a column already, there is no need to re-assign it again + # # this is only the case if the expr is a RexInputRef + # if isinstance(expr, org.apache.calcite.rex.RexInputRef): + # index = expr.getIndex() + # backend_column_name = cc.get_backend_by_frontend_index(index) + # logger.debug( + # f"Not re-adding the same column {key} (but just referencing it)" + # ) + # new_mappings[key] = backend_column_name + # else: + # random_name = new_temporary_column(df) + # new_columns[random_name] = RexConverter.convert( + # expr, dc, context=context + # ) + # logger.debug(f"Adding a new column {key} out of {expr}") + # new_mappings[key] = random_name + + # # Actually add the new columns + # if new_columns: + # df = df.assign(**new_columns) + + # # and the new mappings + # for key, backend_column_name in new_mappings.items(): + # cc = cc.add(key, backend_column_name) + + # # Make sure the order is correct + # cc = cc.limit_to(column_names) + + # cc = self.fix_column_to_row_type(cc, rel.getRowType()) + # dc = DataContainer(df, cc) + # dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + # return dc diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 8a6375e9a..f9c614746 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -34,32 +34,26 @@ def convert( self.assert_inputs(rel, 0) # The table(s) we need to return - table = rel.table() - field_names = rel.get_field_names() + table = rel.getTable() # The table names are all names split by "." 
# We assume to always have the form something.something - table_names = [str(n) for n in table.get_qualified_name(rel)] + table_names = [str(n) for n in table.getQualifiedName(rel)] assert len(table_names) == 2 schema_name = table_names[0] table_name = table_names[1] table_name = table_name.lower() - logger.debug( - f"table_scan.convert() -> schema_name: {schema_name} - table_name: {table_name}" - ) - dc = context.schema[schema_name].tables[table_name] df = dc.df cc = dc.column_container # Make sure we only return the requested columns - # row_type = table.getRowType() - # field_specifications = [str(f) for f in row_type.getFieldNames()] - # cc = cc.limit_to(field_specifications) - cc = cc.limit_to(field_names) + row_type = table.getRowType() + field_specifications = [str(f) for f in row_type.getFieldNames()] + cc = cc.limit_to(field_specifications) - cc = self.fix_column_to_row_type(cc, table.column_names()) + cc = self.fix_column_to_row_type(cc, row_type) dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, table) + dc = self.fix_dtype_to_row_type(dc, row_type) return dc From 7b20e6699cee69e249dfbb67d1c86d889207b580 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 22 Apr 2022 15:40:44 -0400 Subject: [PATCH 14/61] impl PyExpr.getIndex() --- dask_planner/src/expression.rs | 144 ++++++++++----------- dask_planner/src/lib.rs | 1 + dask_planner/src/sql/logical/aggregate.rs | 6 +- dask_planner/src/sql/logical/filter.rs | 4 +- dask_planner/src/sql/logical/projection.rs | 79 +++-------- dask_planner/src/sql/types.rs | 9 ++ dask_sql/datacontainer.py | 1 - dask_sql/physical/rel/base.py | 1 - dask_sql/physical/rel/logical/aggregate.py | 6 +- dask_sql/physical/rel/logical/project.py | 91 ++++--------- dask_sql/physical/rex/convert.py | 2 +- 11 files changed, 134 insertions(+), 210 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index c0fc34068..c9b1e58b5 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -1,8 +1,11 @@ use crate::sql::logical; +use crate::sql::types::RexType; -use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; +use pyo3::prelude::*; use std::convert::{From, Into}; +use datafusion::error::DataFusionError; + use datafusion::arrow::datatypes::DataType; use datafusion_expr::{col, lit, BuiltinScalarFunction, Expr}; @@ -10,10 +13,14 @@ use datafusion::scalar::ScalarValue; pub use datafusion_expr::LogicalPlan; +use std::sync::Arc; + + /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expression", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyExpr { + pub input_plan: Option>, pub expr: Expr, } @@ -25,7 +32,10 @@ impl From for Expr { impl From for PyExpr { fn from(expr: Expr) -> PyExpr { - PyExpr { expr } + PyExpr { + input_plan: None, + expr: expr , + } } } @@ -47,61 +57,16 @@ impl From for PyScalarValue { } } -#[pyproto] -impl PyNumberProtocol for PyExpr { - fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok((lhs.expr + rhs.expr).into()) - } - - fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok((lhs.expr - rhs.expr).into()) - } - - fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok((lhs.expr / rhs.expr).into()) - } - - fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok((lhs.expr * rhs.expr).into()) - } - - fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok(lhs.expr.modulus(rhs.expr).into()) - } - - fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult { - Ok(lhs.expr.and(rhs.expr).into()) - } - - fn __or__(lhs: 
PyExpr, rhs: PyExpr) -> PyResult { - Ok(lhs.expr.or(rhs.expr).into()) - } - - fn __invert__(&self) -> PyResult { - Ok(self.expr.clone().not().into()) - } -} - -#[pyproto] -impl PyObjectProtocol for PyExpr { - fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr { - let expr = match op { - CompareOp::Lt => self.expr.clone().lt(other.expr), - CompareOp::Le => self.expr.clone().lt_eq(other.expr), - CompareOp::Eq => self.expr.clone().eq(other.expr), - CompareOp::Ne => self.expr.clone().not_eq(other.expr), - CompareOp::Gt => self.expr.clone().gt(other.expr), - CompareOp::Ge => self.expr.clone().gt_eq(other.expr), - }; - expr.into() - } +impl PyExpr { - fn __str__(&self) -> PyResult { - Ok(format!("{}", self.expr)) + /// Generally we would implement the `From` trait offered by Rust + /// However in this case Expr does not contain the contextual + /// `LogicalPlan` instance that we need so we need to make a instance + /// function to take and create the PyExpr. + pub fn from(expr: Expr, input: Option>) -> PyExpr { + PyExpr { input_plan: input, expr:expr } } -} -impl PyExpr { fn _column_name(&self, mut plan: LogicalPlan) -> String { match &self.expr { Expr::Alias(expr, name) => { @@ -205,10 +170,32 @@ impl PyExpr { } } + /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema + #[pyo3(name = "getIndex")] + pub fn index(&self) -> PyResult { + let input: &Option> = &self.input_plan; + match input { + Some(plan) => { + let name: Result = self.expr.name(plan.schema()); + println!("Column NAME: {:?}", name); + match name { + Ok(k) => { + let index: usize = plan.schema().index_of(&k).unwrap(); + println!("Index: {:?}", index); + Ok(index) + }, + Err(e) => panic!("{:?}", e), + } + }, + None => panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema"), + } + } + /// Examine the current/"self" PyExpr and return its "type" /// In this context a "type" is what Dask-SQL Python /// RexConverter plugin instance should be invoked to handle /// the Rex conversion + #[pyo3(name = "getExprType")] pub fn get_expr_type(&self) -> String { String::from(match &self.expr { Expr::Alias(..) => "Alias", @@ -236,6 +223,35 @@ impl PyExpr { }) } + /// Determines the type of this Expr based on its variant + #[pyo3(name = "getRexType")] + pub fn rex_type(&self) -> RexType { + match &self.expr { + Expr::Alias(..) => RexType::Reference, + Expr::Column(..) => RexType::Reference, + Expr::ScalarVariable(..) => RexType::Literal, + Expr::Literal(..) => RexType::Literal, + Expr::BinaryExpr { .. } => RexType::Call, + Expr::Not(..) => RexType::Call, + Expr::IsNotNull(..) => RexType::Call, + Expr::Negative(..) => RexType::Call, + Expr::GetIndexedField { .. } => RexType::Reference, + Expr::IsNull(..) => RexType::Call, + Expr::Between { .. } => RexType::Call, + Expr::Case { .. } => RexType::Call, + Expr::Cast { .. } => RexType::Call, + Expr::TryCast { .. } => RexType::Call, + Expr::Sort { .. } => RexType::Call, + Expr::ScalarFunction { .. } => RexType::Call, + Expr::AggregateFunction { .. } => RexType::Call, + Expr::WindowFunction { .. } => RexType::Call, + Expr::AggregateUDF { .. } => RexType::Call, + Expr::InList { .. 
} => RexType::Call, + Expr::Wildcard => RexType::Call, + _ => RexType::Other, + } + } + /// Python friendly shim code to get the name of a column referenced by an expression pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { self._column_name(plan.current_node()) @@ -367,26 +383,6 @@ impl PyExpr { } } - #[staticmethod] - pub fn column(value: &str) -> PyExpr { - col(value).into() - } - - /// assign a name to the PyExpr - pub fn alias(&self, name: &str) -> PyExpr { - self.expr.clone().alias(name).into() - } - - /// Create a sort PyExpr from an existing PyExpr. - #[args(ascending = true, nulls_first = true)] - pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { - self.expr.clone().sort(ascending, nulls_first).into() - } - - pub fn is_null(&self) -> PyExpr { - self.expr.clone().is_null().into() - } - /// TODO: I can't express how much I dislike explicity listing all of these methods out /// but PyO3 makes it necessary since its annotations cannot be used in trait impl blocks #[pyo3(name = "getFloat32Value")] diff --git a/dask_planner/src/lib.rs b/dask_planner/src/lib.rs index 550175c93..937971904 100644 --- a/dask_planner/src/lib.rs +++ b/dask_planner/src/lib.rs @@ -18,6 +18,7 @@ fn rust(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/dask_planner/src/sql/logical/aggregate.rs b/dask_planner/src/sql/logical/aggregate.rs index e260e4bd7..af50ce0d5 100644 --- a/dask_planner/src/sql/logical/aggregate.rs +++ b/dask_planner/src/sql/logical/aggregate.rs @@ -18,7 +18,7 @@ impl PyAggregate { pub fn group_expressions(&self) -> PyResult> { let mut group_exprs: Vec = Vec::new(); for expr in &self.aggregate.group_expr { - group_exprs.push(expr.clone().into()); + group_exprs.push(PyExpr::from(expr.clone(), Some(self.aggregate.input.clone()))); } Ok(group_exprs) } @@ -27,7 +27,7 @@ impl PyAggregate { pub fn agg_expressions(&self) -> PyResult> { let mut agg_exprs: Vec = Vec::new(); for expr in &self.aggregate.aggr_expr { - agg_exprs.push(expr.clone().into()); + agg_exprs.push(PyExpr::from(expr.clone(), Some(self.aggregate.input.clone()))); } Ok(agg_exprs) } @@ -46,7 +46,7 @@ impl PyAggregate { Expr::AggregateFunction { fun: _, args, .. 
} => { let mut exprs: Vec = Vec::new(); for expr in args { - exprs.push(PyExpr { expr }); + exprs.push(PyExpr { input_plan: Some(self.aggregate.input.clone()), expr: expr }); } exprs } diff --git a/dask_planner/src/sql/logical/filter.rs b/dask_planner/src/sql/logical/filter.rs index 2ef163721..d0266dbc4 100644 --- a/dask_planner/src/sql/logical/filter.rs +++ b/dask_planner/src/sql/logical/filter.rs @@ -16,14 +16,14 @@ impl PyFilter { /// LogicalPlan::Filter: The PyExpr, predicate, that represents the filtering condition #[pyo3(name = "getCondition")] pub fn get_condition(&mut self) -> PyResult { - Ok(self.filter.predicate.clone().into()) + Ok(PyExpr::from(self.filter.predicate.clone(), Some(self.filter.input.clone()))) } } impl From for PyFilter { fn from(logical_plan: LogicalPlan) -> PyFilter { match logical_plan { - LogicalPlan::Filter(filter) => PyFilter { filter }, + LogicalPlan::Filter(filter) => PyFilter { filter: filter }, _ => panic!("something went wrong here"), } } diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index d5ef65827..2d24c5fae 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -14,7 +14,7 @@ pub struct PyProjection { #[pymethods] impl PyProjection { #[pyo3(name = "getColumnName")] - fn named_projects(&mut self, expr: PyExpr) -> PyResult { + fn column_name(&mut self, expr: PyExpr) -> PyResult { let mut val: String = String::from("OK"); match expr.expr { Expr::Alias(expr, _alias) => match expr.as_ref() { @@ -38,9 +38,10 @@ impl PyProjection { _ => unimplemented!(), } } - _ => println!("not supported: {:?}", expr), + _ => panic!("not supported: {:?}", expr), }, - _ => println!("Ignore for now"), + Expr::Column(col) => val = col.name.clone(), + _ => panic!("Ignore for now"), } Ok(val) } @@ -50,72 +51,30 @@ impl PyProjection { fn projected_expressions(&mut self) -> PyResult> { let mut projs: Vec = Vec::new(); for expr in &self.projection.expr { - projs.push(expr.clone().into()); + projs.push(PyExpr::from(expr.clone(), Some(self.projection.input.clone()))); } Ok(projs) } - // fn named_projects(&mut self) { - // for expr in &self.projection.expr { - // match expr { - // Expr::Alias(expr, alias) => { - // match expr.as_ref() { - // Expr::Column(col) => { - // let index = self.projection.input.schema().index_of_column(&col).unwrap(); - // println!("projection column '{}' maps to input column {}", col.to_string(), index); - // let f: &DFField = self.projection.input.schema().field(index); - // println!("Field: {:?}", f); - // match self.projection.input.as_ref() { - // LogicalPlan::Aggregate(agg) => { - // let mut exprs = agg.group_expr.clone(); - // exprs.extend_from_slice(&agg.aggr_expr); - // match &exprs[index] { - // Expr::AggregateFunction { args, .. } => { - // match &args[0] { - // Expr::Column(col) => { - // println!("AGGREGATE COLUMN IS {}", col.name); - // }, - // _ => unimplemented!() - // } - // }, - // _ => unimplemented!() - // } - // }, - // _ => unimplemented!() - // } - // } - // _ => unimplemented!() - // } - // }, - // _ => println!("not supported: {:?}", expr) - // } - // } - // } - - // fn named_projects(&mut self) { - // match self.projection.input.as_ref() { - // LogicalPlan::Aggregate(agg) => { - // match &agg.aggr_expr[0] { - // Expr::AggregateFunction { args, .. 
} => { - // match &args[0] { - // Expr::Column(col) => { - // println!("AGGREGATE COLUMN IS {}", col.name); - // }, - // _ => unimplemented!() - // } - // }, - // _ => println!("ignore for now") - // } - // }, - // _ => unimplemented!() - // } - // } + #[pyo3(name = "getNamedProjects")] + fn named_projects(&mut self) -> PyResult> { + let mut named: Vec<(String, PyExpr)> = Vec::new(); + for expr in &self.projected_expressions().unwrap() { + let name: String = self.column_name(expr.clone()).unwrap(); + named.push((name, expr.clone())); + } + Ok(named) + } } impl From for PyProjection { fn from(logical_plan: LogicalPlan) -> PyProjection { match logical_plan { - LogicalPlan::Projection(projection) => PyProjection { projection }, + LogicalPlan::Projection(projection) => { + PyProjection { + projection: projection + } + }, _ => panic!("something went wrong here"), } } diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 62f26e9c4..21ca73282 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -5,6 +5,15 @@ pub mod rel_data_type_field; use pyo3::prelude::*; +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "RexType", module = "datafusion")] +pub enum RexType { + Literal, + Call, + Reference, + Other, +} + /// Enumeration of the type names which can be used to construct a SQL type. Since /// several SQL types do not exist as Rust types and also because the Enum /// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index e92f5b6e3..db77c9dfc 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -121,7 +121,6 @@ def get_backend_by_frontend_index(self, index: int) -> str: Get back the dask column, which is referenced by the frontend (SQL) column with the given index. """ - print(f"self._frontend_columns: {self._frontend_columns} index: {index}") frontend_column = self._frontend_columns[index] backend_column = self._frontend_backend_mapping[frontend_column] return backend_column diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index d49797f21..801d2d84c 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -38,7 +38,6 @@ def fix_column_to_row_type( We assume that the column order is already correct and will just "blindly" rename the columns. """ - print(f"type(row_type): {type(row_type)}") field_names = [str(x) for x in row_type.getFieldNames()] logger.debug(f"Renaming {cc.columns} to {field_names}") diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index 6e0214ab2..19f05ab11 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -299,11 +299,11 @@ def _collect_aggregations( for expr in rel.aggregate().getNamedAggCalls(): logger.debug(f"Aggregate Call: {expr}") - logger.debug(f"Expr Type: {expr.get_expr_type()}") + logger.debug(f"Expr Type: {expr.getExprType()}") # Determine the aggregation function to use assert ( - expr.get_expr_type() == "AggregateFunction" + expr.getExprType() == "AggregateFunction" ), "Do not know how to handle this case!" 
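            # Hedged illustration of what the new getExprType() binding returns:
            # the Rust Expr variant name as a plain string. The example values
            # below are assumptions for illustration, not captured output:
            #
            #   expr.getExprType()  # e.g. "AggregateFunction" for SUM(a) or COUNT(a)
            #   expr.getExprType()  # e.g. "Column" for a bare column reference
            #
            # so the assert above rejects anything that is not an aggregate call.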
# TODO: Generally we need a way to capture the current SQL schema here in case this is a custom aggregation function @@ -315,7 +315,7 @@ def _collect_aggregations( inputs = rel.aggregate().getArgs(expr) logger.debug(f"Number of Inputs: {len(inputs)}") logger.debug( - f"Input: {inputs[0]} of type: {inputs[0].get_expr_type()} with column name: {inputs[0].column_name(rel)}" + f"Input: {inputs[0]} of type: {inputs[0].getExprType()} with column name: {inputs[0].column_name(rel)}" ) # TODO: This if statement is likely no longer needed but left here for the time being just in case diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index ddac9fe66..c943ca2ae 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -6,6 +6,8 @@ from dask_sql.physical.rex import RexConverter from dask_sql.utils import new_temporary_column +from dask_planner.rust import RexType + if TYPE_CHECKING: import dask_sql from dask_planner.rust import LogicalPlan @@ -30,47 +32,43 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai cc = dc.column_container # Collect all (new) columns - # named_projects = rel.getNamedProjects() + proj = rel.projection() + named_projects = proj.getNamedProjects() column_names = [] new_columns = {} new_mappings = {} - projection = rel.projection() - # Collect all (new) columns this Projection will limit to - for expr in projection.getProjectedExpressions(): + for key, expr in named_projects: - key = str(expr.column_name(rel)) + key = str(key) column_names.append(key) - # TODO: Temporarily assigning all new rows to increase the flexibility of the code base, - # later it will be added back it is just too early in the process right now to be feasible - - # # shortcut: if we have a column already, there is no need to re-assign it again - # # this is only the case if the expr is a RexInputRef - # if isinstance(expr, org.apache.calcite.rex.RexInputRef): - # index = expr.getIndex() - # backend_column_name = cc.get_backend_by_frontend_index(index) - # logger.debug( - # f"Not re-adding the same column {key} (but just referencing it)" - # ) - # new_mappings[key] = backend_column_name - # else: - # random_name = new_temporary_column(df) - # new_columns[random_name] = RexConverter.convert( - # expr, dc, context=context - # ) - # logger.debug(f"Adding a new column {key} out of {expr}") - # new_mappings[key] = random_name - random_name = new_temporary_column(df) new_columns[random_name] = RexConverter.convert( rel, expr, dc, context=context ) - logger.debug(f"Adding a new column {key} out of {expr}") + new_mappings[key] = random_name + # shortcut: if we have a column already, there is no need to re-assign it again + # this is only the case if the expr is a RexInputRef + if expr.getRexType() == RexType.Reference: + index = expr.getIndex() + backend_column_name = cc.get_backend_by_frontend_index(index) + logger.debug( + f"Not re-adding the same column {key} (but just referencing it)" + ) + new_mappings[key] = backend_column_name + else: + random_name = new_temporary_column(df) + new_columns[random_name] = RexConverter.convert( + expr, dc, context=context + ) + logger.debug(f"Adding a new column {key} out of {expr}") + new_mappings[key] = random_name + # Actually add the new columns if new_columns: df = df.assign(**new_columns) @@ -82,44 +80,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # Make sure the order is correct cc = cc.limit_to(column_names) - 
cc = self.fix_column_to_row_type(cc, column_names) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, rel.table()) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc - - # for expr, key in named_projects: - # key = str(key) - # column_names.append(key) - - # # shortcut: if we have a column already, there is no need to re-assign it again - # # this is only the case if the expr is a RexInputRef - # if isinstance(expr, org.apache.calcite.rex.RexInputRef): - # index = expr.getIndex() - # backend_column_name = cc.get_backend_by_frontend_index(index) - # logger.debug( - # f"Not re-adding the same column {key} (but just referencing it)" - # ) - # new_mappings[key] = backend_column_name - # else: - # random_name = new_temporary_column(df) - # new_columns[random_name] = RexConverter.convert( - # expr, dc, context=context - # ) - # logger.debug(f"Adding a new column {key} out of {expr}") - # new_mappings[key] = random_name - - # # Actually add the new columns - # if new_columns: - # df = df.assign(**new_columns) - - # # and the new mappings - # for key, backend_column_name in new_mappings.items(): - # cc = cc.add(key, backend_column_name) - - # # Make sure the order is correct - # cc = cc.limit_to(column_names) - - # cc = self.fix_column_to_row_type(cc, rel.getRowType()) - # dc = DataContainer(df, cc) - # dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) - # return dc diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index f1a54c145..1123e8359 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -60,7 +60,7 @@ def convert( using the stored plugins and the dictionary of registered dask tables. """ - expr_type = _REX_TYPE_TO_PLUGIN[rex.get_expr_type()] + expr_type = _REX_TYPE_TO_PLUGIN[rex.getExprType()] try: plugin_instance = cls.get_plugin(expr_type) From 7dd2017afb67dd3bed0ee8f16135f7628d0c1982 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 22 Apr 2022 18:45:18 -0400 Subject: [PATCH 15/61] add getRowType() for logical.rs --- dask_planner/src/expression.rs | 29 ++++++++++--------- dask_planner/src/sql/logical.rs | 19 ++++++++++++ dask_planner/src/sql/logical/aggregate.rs | 15 ++++++++-- dask_planner/src/sql/logical/filter.rs | 5 +++- dask_planner/src/sql/logical/projection.rs | 11 +++---- .../src/sql/types/rel_data_type_field.rs | 21 +++++++++++--- 6 files changed, 74 insertions(+), 26 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index c9b1e58b5..4f589f1d1 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -13,8 +13,9 @@ use datafusion::scalar::ScalarValue; pub use datafusion_expr::LogicalPlan; -use std::sync::Arc; +use datafusion::prelude::Column; +use std::sync::Arc; /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expression", module = "datafusion", subclass)] @@ -34,7 +35,7 @@ impl From for PyExpr { fn from(expr: Expr) -> PyExpr { PyExpr { input_plan: None, - expr: expr , + expr: expr, } } } @@ -58,13 +59,15 @@ impl From for PyScalarValue { } impl PyExpr { - /// Generally we would implement the `From` trait offered by Rust - /// However in this case Expr does not contain the contextual + /// However in this case Expr does not contain the contextual /// `LogicalPlan` instance that we need so we need to make a instance /// function to take and create the PyExpr. 
pub fn from(expr: Expr, input: Option>) -> PyExpr { - PyExpr { input_plan: input, expr:expr } + PyExpr { + input_plan: input, + expr: expr, + } } fn _column_name(&self, mut plan: LogicalPlan) -> String { @@ -177,17 +180,17 @@ impl PyExpr { match input { Some(plan) => { let name: Result = self.expr.name(plan.schema()); - println!("Column NAME: {:?}", name); match name { - Ok(k) => { - let index: usize = plan.schema().index_of(&k).unwrap(); - println!("Index: {:?}", index); - Ok(index) - }, + Ok(fq_name) => Ok(plan + .schema() + .index_of_column(&Column::from_qualified_name(&fq_name)) + .unwrap()), Err(e) => panic!("{:?}", e), } - }, - None => panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema"), + } + None => { + panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") + } } } diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 57ebacfeb..9028b2aba 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -1,4 +1,7 @@ use crate::sql::table; +use crate::sql::types::rel_data_type::RelDataType; +use crate::sql::types::rel_data_type_field::RelDataTypeField; + mod aggregate; mod filter; mod join; @@ -133,6 +136,22 @@ impl PyLogicalPlan { pub fn explain_current(&mut self) -> PyResult { Ok(format!("{}", self.current_node().display_indent())) } + + #[pyo3(name = "getRowType")] + pub fn row_type(&self) -> RelDataType { + let fields: &Vec = self.original_plan.schema().fields(); + let mut rel_fields: Vec = Vec::new(); + for i in 0..fields.len() { + rel_fields.push( + RelDataTypeField::from( + fields[i].clone(), + self.original_plan.schema().as_ref().clone(), + ) + .unwrap(), + ); + } + RelDataType::new(false, rel_fields) + } } impl From for LogicalPlan { diff --git a/dask_planner/src/sql/logical/aggregate.rs b/dask_planner/src/sql/logical/aggregate.rs index af50ce0d5..726a73552 100644 --- a/dask_planner/src/sql/logical/aggregate.rs +++ b/dask_planner/src/sql/logical/aggregate.rs @@ -18,7 +18,10 @@ impl PyAggregate { pub fn group_expressions(&self) -> PyResult> { let mut group_exprs: Vec = Vec::new(); for expr in &self.aggregate.group_expr { - group_exprs.push(PyExpr::from(expr.clone(), Some(self.aggregate.input.clone()))); + group_exprs.push(PyExpr::from( + expr.clone(), + Some(self.aggregate.input.clone()), + )); } Ok(group_exprs) } @@ -27,7 +30,10 @@ impl PyAggregate { pub fn agg_expressions(&self) -> PyResult> { let mut agg_exprs: Vec = Vec::new(); for expr in &self.aggregate.aggr_expr { - agg_exprs.push(PyExpr::from(expr.clone(), Some(self.aggregate.input.clone()))); + agg_exprs.push(PyExpr::from( + expr.clone(), + Some(self.aggregate.input.clone()), + )); } Ok(agg_exprs) } @@ -46,7 +52,10 @@ impl PyAggregate { Expr::AggregateFunction { fun: _, args, .. 
} => { let mut exprs: Vec = Vec::new(); for expr in args { - exprs.push(PyExpr { input_plan: Some(self.aggregate.input.clone()), expr: expr }); + exprs.push(PyExpr { + input_plan: Some(self.aggregate.input.clone()), + expr: expr, + }); } exprs } diff --git a/dask_planner/src/sql/logical/filter.rs b/dask_planner/src/sql/logical/filter.rs index d0266dbc4..4474ad1c6 100644 --- a/dask_planner/src/sql/logical/filter.rs +++ b/dask_planner/src/sql/logical/filter.rs @@ -16,7 +16,10 @@ impl PyFilter { /// LogicalPlan::Filter: The PyExpr, predicate, that represents the filtering condition #[pyo3(name = "getCondition")] pub fn get_condition(&mut self) -> PyResult { - Ok(PyExpr::from(self.filter.predicate.clone(), Some(self.filter.input.clone()))) + Ok(PyExpr::from( + self.filter.predicate.clone(), + Some(self.filter.input.clone()), + )) } } diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 2d24c5fae..4ea9b21d6 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -51,7 +51,10 @@ impl PyProjection { fn projected_expressions(&mut self) -> PyResult> { let mut projs: Vec = Vec::new(); for expr in &self.projection.expr { - projs.push(PyExpr::from(expr.clone(), Some(self.projection.input.clone()))); + projs.push(PyExpr::from( + expr.clone(), + Some(self.projection.input.clone()), + )); } Ok(projs) } @@ -70,10 +73,8 @@ impl PyProjection { impl From for PyProjection { fn from(logical_plan: LogicalPlan) -> PyProjection { match logical_plan { - LogicalPlan::Projection(projection) => { - PyProjection { - projection: projection - } + LogicalPlan::Projection(projection) => PyProjection { + projection: projection, }, _ => panic!("something went wrong here"), } diff --git a/dask_planner/src/sql/types/rel_data_type_field.rs b/dask_planner/src/sql/types/rel_data_type_field.rs index 1a01e32d0..d2a0823d4 100644 --- a/dask_planner/src/sql/types/rel_data_type_field.rs +++ b/dask_planner/src/sql/types/rel_data_type_field.rs @@ -1,6 +1,8 @@ -use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::SqlTypeName; +use datafusion::error::DataFusionError; +use datafusion::logical_plan::{DFField, DFSchema}; + use std::fmt; use pyo3::prelude::*; @@ -11,13 +13,24 @@ use pyo3::prelude::*; pub struct RelDataTypeField { name: String, data_type: SqlTypeName, - index: u8, + index: usize, +} + +// Functions that should not be presented to Python are placed here +impl RelDataTypeField { + pub fn from(field: DFField, schema: DFSchema) -> Result { + Ok(RelDataTypeField { + name: field.name().clone(), + data_type: SqlTypeName::from_arrow(field.data_type()), + index: schema.index_of(field.name())?, + }) + } } #[pymethods] impl RelDataTypeField { #[new] - pub fn new(name: String, data_type: SqlTypeName, index: u8) -> Self { + pub fn new(name: String, data_type: SqlTypeName, index: usize) -> Self { Self { name: name, data_type: data_type, @@ -31,7 +44,7 @@ impl RelDataTypeField { } #[pyo3(name = "getIndex")] - pub fn index(&self) -> u8 { + pub fn index(&self) -> usize { self.index } From 984f5238245fc4bda0431d30cf140883b1eea52e Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 22 Apr 2022 22:14:58 -0400 Subject: [PATCH 16/61] Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes --- dask_planner/Cargo.toml | 2 +- dask_planner/src/lib.rs | 1 + dask_planner/src/sql.rs | 12 +- dask_planner/src/sql/logical.rs | 1 + dask_planner/src/sql/table.rs | 12 +- dask_planner/src/sql/types.rs | 105 
++++++++++++--- .../src/sql/types/rel_data_type_field.rs | 18 ++- dask_sql/context.py | 10 +- dask_sql/mappings.py | 29 ++-- dask_sql/physical/rel/base.py | 14 +- dask_sql/physical/rel/logical/project.py | 10 +- dask_sql/physical/rel/logical/table_scan.py | 9 +- tests/integration/test_select.py | 127 +++++++++--------- 13 files changed, 220 insertions(+), 130 deletions(-) diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index c66fc637c..bc7e3138a 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -11,7 +11,7 @@ rust-version = "1.59" [dependencies] tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" -pyo3 = { version = "0.15", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] } datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } uuid = { version = "0.8", features = ["v4"] } diff --git a/dask_planner/src/lib.rs b/dask_planner/src/lib.rs index 937971904..43b27b3b1 100644 --- a/dask_planner/src/lib.rs +++ b/dask_planner/src/lib.rs @@ -19,6 +19,7 @@ fn rust(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 88fa5b901..05650bbb1 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -61,7 +61,7 @@ impl ContextProvider for DaskSQLContext { let mut fields: Vec = Vec::new(); // Iterate through the DaskTable instance and create a Schema instance for (column_name, column_type) in &table.columns { - fields.push(Field::new(column_name, column_type.to_arrow(), false)); + fields.push(Field::new(column_name, column_type.data_type(), false)); } resp = Some(Schema::new(fields)); @@ -148,10 +148,7 @@ impl DaskSQLContext { ); Ok(statements) } - Err(e) => Err(PyErr::new::(format!( - "{}", - e - ))), + Err(e) => Err(PyErr::new::(format!("{}", e))), } } @@ -170,10 +167,7 @@ impl DaskSQLContext { current_node: None, }) } - Err(e) => Err(PyErr::new::(format!( - "{}", - e - ))), + Err(e) => Err(PyErr::new::(format!("{}", e))), } } } diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 9028b2aba..0c9ca27f5 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -1,6 +1,7 @@ use crate::sql::table; use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::rel_data_type_field::RelDataTypeField; +use datafusion::logical_plan::DFField; mod aggregate; mod filter; diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 4bf0ff292..eebc6ff7f 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -1,6 +1,7 @@ use crate::sql::logical; use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::rel_data_type_field::RelDataTypeField; +use crate::sql::types::DaskTypeMap; use crate::sql::types::SqlTypeName; use async_trait::async_trait; @@ -95,7 +96,7 @@ pub struct DaskTable { pub(crate) name: String, #[allow(dead_code)] pub(crate) statistics: DaskStatistics, - pub(crate) columns: Vec<(String, SqlTypeName)>, + pub(crate) columns: Vec<(String, DaskTypeMap)>, } #[pymethods] @@ -111,9 +112,8 @@ impl DaskTable { // TODO: Really 
wish we could accept a SqlTypeName instance here instead of a String for `column_type` .... #[pyo3(name = "add_column")] - pub fn add_column(&mut self, column_name: String, column_type: String) { - self.columns - .push((column_name, SqlTypeName::from_string(&column_type))); + pub fn add_column(&mut self, column_name: String, type_map: DaskTypeMap) { + self.columns.push((column_name, type_map)); } #[pyo3(name = "getQualifiedName")] @@ -154,12 +154,12 @@ pub(crate) fn table_from_logical_plan(plan: &LogicalPlan) -> Option { let tbl_schema: SchemaRef = tbl_provider.schema(); let fields: &Vec = tbl_schema.fields(); - let mut cols: Vec<(String, SqlTypeName)> = Vec::new(); + let mut cols: Vec<(String, DaskTypeMap)> = Vec::new(); for field in fields { let data_type: &DataType = field.data_type(); cols.push(( String::from(field.name()), - SqlTypeName::from_arrow(data_type), + DaskTypeMap::from(SqlTypeName::from_arrow(data_type), data_type.clone()), )); } diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 21ca73282..cd2f100a8 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -4,6 +4,8 @@ pub mod rel_data_type; pub mod rel_data_type_field; use pyo3::prelude::*; +use pyo3::types::PyAny; +use pyo3::types::PyDict; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "RexType", module = "datafusion")] @@ -14,6 +16,92 @@ pub enum RexType { Other, } +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "DaskTypeMap", module = "datafusion", subclass)] +/// Represents a Python Data Type. This is needed instead of simple +/// Enum instances because PyO3 can only support unit variants as +/// of version 0.16 which means Enums like `DataType::TIMESTAMP_WITH_LOCAL_TIME_ZONE` +/// which generally hold `unit` and `tz` information are unable to +/// do that so data is lost. This struct aims to solve that issue +/// by taking the type Enum from Python and some optional extra +/// parameters that can be used to properly create those DataType +/// instances in Rust. 
+pub struct DaskTypeMap { + sql_type: SqlTypeName, + data_type: DataType, +} + +/// Functions not exposed to Python +impl DaskTypeMap { + pub fn from(sql_type: SqlTypeName, data_type: DataType) -> Self { + DaskTypeMap { + sql_type: sql_type, + data_type: data_type, + } + } + + pub fn data_type(&self) -> DataType { + self.data_type.clone() + } +} + +#[pymethods] +impl DaskTypeMap { + #[new] + #[args(sql_type, py_kwargs = "**")] + fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> Self { + println!("sql_type={:?} - py_kwargs={:?}", sql_type, py_kwargs); + + let d_type: DataType = match sql_type { + SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { + let (unit, tz) = match py_kwargs { + Some(dict) => { + let tz: Option = match dict.get_item("tz") { + Some(e) => { + let res: PyResult = e.extract(); + Some(res.unwrap()) + } + None => None, + }; + let unit: TimeUnit = match dict.get_item("unit") { + Some(e) => { + let res: PyResult<&str> = e.extract(); + match res.unwrap() { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => TimeUnit::Nanosecond, + } + } + // Default to Nanosecond which is common if not present + None => TimeUnit::Nanosecond, + }; + (unit, tz) + } + // Default to Nanosecond and None for tz which is common if not present + None => (TimeUnit::Nanosecond, None), + }; + DataType::Timestamp(unit, tz) + } + _ => { + panic!("stop here"); + // sql_type.to_arrow() + } + }; + + DaskTypeMap { + sql_type: sql_type, + data_type: d_type, + } + } + + #[pyo3(name = "getSqlType")] + pub fn sql_type(&self) -> SqlTypeName { + self.sql_type.clone() + } +} + /// Enumeration of the type names which can be used to construct a SQL type. Since /// several SQL types do not exist as Rust types and also because the Enum /// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used @@ -73,23 +161,6 @@ pub enum SqlTypeName { } impl SqlTypeName { - pub fn to_arrow(&self) -> DataType { - match self { - SqlTypeName::NULL => DataType::Null, - SqlTypeName::BOOLEAN => DataType::Boolean, - SqlTypeName::TINYINT => DataType::Int8, - SqlTypeName::SMALLINT => DataType::Int16, - SqlTypeName::INTEGER => DataType::Int32, - SqlTypeName::BIGINT => DataType::Int64, - SqlTypeName::REAL => DataType::Float16, - SqlTypeName::FLOAT => DataType::Float32, - SqlTypeName::DOUBLE => DataType::Float64, - SqlTypeName::DATE => DataType::Date64, - SqlTypeName::TIMESTAMP => DataType::Timestamp(TimeUnit::Nanosecond, None), - _ => todo!(), - } - } - pub fn from_arrow(data_type: &DataType) -> Self { match data_type { DataType::Null => SqlTypeName::NULL, diff --git a/dask_planner/src/sql/types/rel_data_type_field.rs b/dask_planner/src/sql/types/rel_data_type_field.rs index d2a0823d4..754b93f42 100644 --- a/dask_planner/src/sql/types/rel_data_type_field.rs +++ b/dask_planner/src/sql/types/rel_data_type_field.rs @@ -1,3 +1,4 @@ +use crate::sql::types::DaskTypeMap; use crate::sql::types::SqlTypeName; use datafusion::error::DataFusionError; @@ -12,7 +13,7 @@ use pyo3::prelude::*; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RelDataTypeField { name: String, - data_type: SqlTypeName, + data_type: DaskTypeMap, index: usize, } @@ -21,7 +22,10 @@ impl RelDataTypeField { pub fn from(field: DFField, schema: DFSchema) -> Result { Ok(RelDataTypeField { name: field.name().clone(), - data_type: SqlTypeName::from_arrow(field.data_type()), + data_type: DaskTypeMap { + sql_type: 
SqlTypeName::from_arrow(field.data_type()), + data_type: field.data_type().clone(), + }, index: schema.index_of(field.name())?, }) } @@ -30,10 +34,10 @@ impl RelDataTypeField { #[pymethods] impl RelDataTypeField { #[new] - pub fn new(name: String, data_type: SqlTypeName, index: usize) -> Self { + pub fn new(name: String, type_map: DaskTypeMap, index: usize) -> Self { Self { name: name, - data_type: data_type, + data_type: type_map, index: index, } } @@ -49,7 +53,7 @@ impl RelDataTypeField { } #[pyo3(name = "getType")] - pub fn data_type(&self) -> SqlTypeName { + pub fn data_type(&self) -> DaskTypeMap { self.data_type.clone() } @@ -65,12 +69,12 @@ impl RelDataTypeField { /// Alas it is used in certain places so it is implemented here to allow other /// places in the code base to not have to change. #[pyo3(name = "getValue")] - pub fn get_value(&self) -> SqlTypeName { + pub fn get_value(&self) -> DaskTypeMap { self.data_type() } #[pyo3(name = "setValue")] - pub fn set_value(&mut self, data_type: SqlTypeName) { + pub fn set_value(&mut self, data_type: DaskTypeMap) { self.data_type = data_type } diff --git a/dask_sql/context.py b/dask_sql/context.py index 9e08091f8..47670cf74 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -10,7 +10,13 @@ from dask.base import optimize from dask.distributed import Client -from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, DFParsingException +from dask_planner.rust import ( + DaskSchema, + DaskSQLContext, + DaskTable, + DaskTypeMap, + DFParsingException, +) try: import dask_cuda # noqa: F401 @@ -746,7 +752,7 @@ def _prepare_schemas(self): for column in df.columns: data_type = df[column].dtype sql_data_type = python_to_sql_type(data_type) - table.add_column(column, str(sql_data_type)) + table.add_column(column, sql_data_type) rust_schema.add_table(table) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 2d6a04069..da35e1e2b 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from dask_planner.rust import SqlTypeName +from dask_planner.rust import DaskTypeMap, SqlTypeName from dask_sql._compat import FLOAT_NAN_IMPLEMENTED logger = logging.getLogger(__name__) @@ -82,16 +82,20 @@ } -def python_to_sql_type(python_type) -> "SqlTypeName": +def python_to_sql_type(python_type) -> "DaskTypeMap": """Mapping between python and SQL types.""" if isinstance(python_type, np.dtype): python_type = python_type.type if pd.api.types.is_datetime64tz_dtype(python_type): - return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE + return DaskTypeMap( + SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, + unit=str(python_type.unit), + tz=str(python_type.tz), + ) try: - return _PYTHON_TO_SQL[python_type] + return DaskTypeMap(_PYTHON_TO_SQL[python_type]) except KeyError: # pragma: no cover raise NotImplementedError( f"The python type {python_type} is not implemented (yet)" @@ -198,24 +202,20 @@ def sql_to_python_type(sql_type: str) -> type: if sql_type.find(".") != -1: sql_type = sql_type.split(".")[1] + print(f"sql_type: {sql_type}") if sql_type.startswith("CHAR(") or sql_type.startswith("VARCHAR("): return pd.StringDtype() elif sql_type.startswith("INTERVAL"): return np.dtype(" bool: TODO: nullability is not checked so far. 
""" + print(f"similar_type: {lhs} - {rhs}") pdt = pd.api.types is_uint = pdt.is_unsigned_integer_dtype is_sint = pdt.is_signed_integer_dtype @@ -273,9 +274,7 @@ def cast_column_type( """ current_type = df[column_name].dtype - logger.debug( - f"Column {column_name} has type {current_type}, expecting {expected_type}..." - ) + print(f"Column {column_name} has type {current_type}, expecting {expected_type}...") casted_column = cast_column_to_type(df[column_name], expected_type) @@ -303,5 +302,5 @@ def cast_column_to_type(col: dd.Series, expected_type: str): # will convert both NA and np.NaN to NA. col = da.trunc(col.fillna(value=np.NaN)) - logger.debug(f"Need to cast from {current_type} to {expected_type}") + print(f"Need to cast from {current_type} to {expected_type}") return col.astype(expected_type) diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 801d2d84c..5b6807937 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -98,14 +98,18 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): cc = dc.column_container field_types = { - str(field.getName()): str(field.getType()) - for field in row_type.getFieldList() + str(field.getName()): field.getType() for field in row_type.getFieldList() } - for sql_field_name, field_type in field_types.items(): - expected_type = sql_to_python_type(field_type) - df_field_name = cc.get_backend_by_frontend_name(sql_field_name) + for field_name, field_type in field_types.items(): + expected_type = sql_to_python_type(str(field_type.getSqlType())) + df_field_name = cc.get_backend_by_frontend_name(field_name) + print( + f"Before cast df_field_name: {df_field_name}, expected_type: {expected_type}" + ) + print(f"Before cast: {df.head(10)}") df = cast_column_type(df, df_field_name, expected_type) + print(f"After cast: {df.head(10)}") return DataContainer(df, dc.column_container) diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index c943ca2ae..c33054442 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -1,13 +1,12 @@ import logging from typing import TYPE_CHECKING +from dask_planner.rust import RexType from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter from dask_sql.utils import new_temporary_column -from dask_planner.rust import RexType - if TYPE_CHECKING: import dask_sql from dask_planner.rust import LogicalPlan @@ -31,6 +30,8 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container + print(f"Before Project: {df.head(10)}") + # Collect all (new) columns proj = rel.projection() named_projects = proj.getNamedProjects() @@ -49,7 +50,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai new_columns[random_name] = RexConverter.convert( rel, expr, dc, context=context ) - + new_mappings[key] = random_name # shortcut: if we have a column already, there is no need to re-assign it again @@ -83,4 +84,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + + print(f"After Project: {dc.df.head(10)}") + return dc diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index f9c614746..4b50deb11 100644 
--- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -48,12 +48,17 @@ def convert( df = dc.df cc = dc.column_container + print(f"Before TableScan: {df.head(10)}") + # Make sure we only return the requested columns row_type = table.getRowType() field_specifications = [str(f) for f in row_type.getFieldNames()] cc = cc.limit_to(field_specifications) - cc = self.fix_column_to_row_type(cc, row_type) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, row_type) + print(f"Before TableScan fix dtype: {dc.df.head(10)}") + print(f"row_type: {row_type}") + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + print(f"After TableScan fix dtype: {dc.df.head(10)}") return dc diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index afc9220d9..2182457d6 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -5,22 +5,21 @@ from dask_sql.utils import ParsingException from tests.utils import assert_eq +# def test_select(c, df): +# result_df = c.sql("SELECT * FROM df") -def test_select(c, df): - result_df = c.sql("SELECT * FROM df") +# assert_eq(result_df, df) - assert_eq(result_df, df) +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_alias(c, df): +# result_df = c.sql("SELECT a as b, b as a FROM df") -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_alias(c, df): - result_df = c.sql("SELECT a as b, b as a FROM df") - - expected_df = pd.DataFrame(index=df.index) - expected_df["b"] = df.a - expected_df["a"] = df.b +# expected_df = pd.DataFrame(index=df.index) +# expected_df["b"] = df.a +# expected_df["a"] = df.b - assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) +# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) # def test_select_column(c, df): @@ -49,70 +48,69 @@ def test_select_alias(c, df): # assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_expr(c, df): - result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") - result_df = result_df - - expected_df = pd.DataFrame( - { - "a": df["a"] + 1, - "bla": df["b"], - '"df"."a" - 1': df["a"] - 1, - } - ) - assert_eq(result_df, expected_df) +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_expr(c, df): +# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") +# result_df = result_df +# expected_df = pd.DataFrame( +# { +# "a": df["a"] + 1, +# "bla": df["b"], +# '"df"."a" - 1': df["a"] - 1, +# } +# ) +# assert_eq(result_df, expected_df) -@pytest.mark.skip( - reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" -) -def test_select_of_select(c, df): - result_df = c.sql( - """ - SELECT 2*c AS e, d - 1 AS f - FROM - ( - SELECT a - 1 AS c, 2*b AS d - FROM df - ) AS "inner" - """ - ) - expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) - assert_eq(result_df, expected_df) +# @pytest.mark.skip( +# reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" +# ) +# def test_select_of_select(c, df): +# result_df = c.sql( +# """ +# SELECT 2*c AS e, d - 1 AS f +# FROM +# ( +# SELECT a - 1 AS c, 2*b AS d +# FROM df +# ) AS "inner" +# """ +# ) +# expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) +# assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_of_select_with_casing(c, df): - result_df = c.sql( - """ - SELECT AAA, 
aaa, aAa - FROM - ( - SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA - FROM df - ) AS "inner" - """ - ) - expected_df = pd.DataFrame( - {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} - ) +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_of_select_with_casing(c, df): +# result_df = c.sql( +# """ +# SELECT AAA, aaa, aAa +# FROM +# ( +# SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA +# FROM df +# ) AS "inner" +# """ +# ) - assert_eq(result_df, expected_df) +# expected_df = pd.DataFrame( +# {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} +# ) +# assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") -def test_wrong_input(c): - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") +# def test_wrong_input(c): +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") -@pytest.mark.skip(reason="WIP DataFusion") + +# @pytest.mark.skip(reason="WIP DataFusion") def test_timezones(c, datetime_table): result_df = c.sql( """ @@ -120,6 +118,9 @@ def test_timezones(c, datetime_table): """ ) + print(f"Expected DF: \n{datetime_table.head(10)}\n") + print(f"\nResult DF: \n{result_df.head(10)}") + assert_eq(result_df, datetime_table) From 1405fea3f65146bf1bfe70c6fb1311333a469db6 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 23 Apr 2022 18:55:25 -0400 Subject: [PATCH 17/61] use str values instead of Rust Enums, Python is unable to Hash the Rust Enums if used in a dict --- dask_planner/src/sql/logical/projection.rs | 14 +- dask_planner/src/sql/types.rs | 55 ++++++- dask_sql/mappings.py | 54 +++---- dask_sql/physical/rel/base.py | 4 +- dask_sql/physical/rel/logical/table_scan.py | 5 - tests/integration/test_select.py | 164 ++++++++++---------- 6 files changed, 173 insertions(+), 123 deletions(-) diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 4ea9b21d6..3d2eccdd8 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -30,18 +30,24 @@ impl PyProjection { println!("AGGREGATE COLUMN IS {}", col.name); val = col.name.clone(); } - _ => unimplemented!(), + _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &args[0]), }, - _ => unimplemented!(), + _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &exprs[index]), } } - _ => unimplemented!(), + LogicalPlan::TableScan(table_scan) => val = table_scan.table_name.clone(), + _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input), } } _ => panic!("not supported: {:?}", expr), }, Expr::Column(col) => val = col.name.clone(), - _ => panic!("Ignore for now"), + _ => { + panic!( + "column_name is unimplemented for Expr variant: {:?}", + expr.expr + ); + } } Ok(val) } diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index cd2f100a8..f439f82e9 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -84,9 +84,40 @@ impl DaskTypeMap { }; DataType::Timestamp(unit, tz) } + SqlTypeName::TIMESTAMP => { + let (unit, tz) = match py_kwargs { + Some(dict) => { + let tz: Option = match dict.get_item("tz") { + Some(e) => { + let res: PyResult = e.extract(); + Some(res.unwrap()) + } + None => None, + }; + let unit: TimeUnit = match 
dict.get_item("unit") { + Some(e) => { + let res: PyResult<&str> = e.extract(); + match res.unwrap() { + "Second" => TimeUnit::Second, + "Millisecond" => TimeUnit::Millisecond, + "Microsecond" => TimeUnit::Microsecond, + "Nanosecond" => TimeUnit::Nanosecond, + _ => TimeUnit::Nanosecond, + } + } + // Default to Nanosecond which is common if not present + None => TimeUnit::Nanosecond, + }; + (unit, tz) + } + // Default to Nanosecond and None for tz which is common if not present + None => (TimeUnit::Nanosecond, None), + }; + DataType::Timestamp(unit, tz) + } _ => { - panic!("stop here"); - // sql_type.to_arrow() + // panic!("stop here"); + sql_type.to_arrow() } }; @@ -161,6 +192,26 @@ pub enum SqlTypeName { } impl SqlTypeName { + pub fn to_arrow(&self) -> DataType { + match self { + SqlTypeName::NULL => DataType::Null, + SqlTypeName::BOOLEAN => DataType::Boolean, + SqlTypeName::TINYINT => DataType::Int8, + SqlTypeName::SMALLINT => DataType::Int16, + SqlTypeName::INTEGER => DataType::Int32, + SqlTypeName::BIGINT => DataType::Int64, + SqlTypeName::REAL => DataType::Float16, + SqlTypeName::FLOAT => DataType::Float32, + SqlTypeName::DOUBLE => DataType::Float64, + SqlTypeName::DATE => DataType::Date64, + SqlTypeName::VARCHAR => DataType::Utf8, + _ => { + println!("Type: {:?}", self); + todo!(); + } + } + } + pub fn from_arrow(data_type: &DataType) -> Self { match data_type { DataType::Null => SqlTypeName::NULL, diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index da35e1e2b..dede1f8e4 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -62,23 +62,20 @@ # Default mapping between SQL types and python types # for data frames _SQL_TO_PYTHON_FRAMES = { - "DOUBLE": np.float64, - "FLOAT": np.float32, - "DECIMAL": np.float64, - "BIGINT": pd.Int64Dtype(), - "INTEGER": pd.Int32Dtype(), - "INT": pd.Int32Dtype(), # Although not in the standard, makes compatibility easier - "SMALLINT": pd.Int16Dtype(), - "TINYINT": pd.Int8Dtype(), - "BOOLEAN": pd.BooleanDtype(), - "VARCHAR": pd.StringDtype(), - "CHAR": pd.StringDtype(), - "STRING": pd.StringDtype(), # Although not in the standard, makes compatibility easier - "DATE": np.dtype( + "SqlTypeName.DOUBLE": np.float64, + "SqlTypeName.FLOAT": np.float32, + "SqlTypeName.DECIMAL": np.float64, + "SqlTypeName.BIGINT": pd.Int64Dtype(), + "SqlTypeName.INTEGER": pd.Int32Dtype(), + "SqlTypeName.SMALLINT": pd.Int16Dtype(), + "SqlTypeName.TINYINT": pd.Int8Dtype(), + "SqlTypeName.BOOLEAN": pd.BooleanDtype(), + "SqlTypeName.VARCHAR": pd.StringDtype(), + "SqlTypeName.DATE": np.dtype( " Any: return python_type(literal_value) -def sql_to_python_type(sql_type: str) -> type: +def sql_to_python_type(sql_type: "SqlTypeName") -> type: """Turn an SQL type into a dataframe dtype""" - # Ex: Rust SqlTypeName Enum str value 'SqlTypeName.DOUBLE' - if sql_type.find(".") != -1: - sql_type = sql_type.split(".")[1] + # # Ex: Rust SqlTypeName Enum str value 'SqlTypeName.DOUBLE' + # if sql_type.find(".") != -1: + # sql_type = sql_type.split(".")[1] - print(f"sql_type: {sql_type}") - if sql_type.startswith("CHAR(") or sql_type.startswith("VARCHAR("): + print(f"sql_type: {sql_type} type: {type(sql_type)}") + print(f"equal?: {SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE == sql_type}") + if sql_type == SqlTypeName.VARCHAR or sql_type == SqlTypeName.CHAR: return pd.StringDtype() - elif sql_type.startswith("INTERVAL"): - return np.dtype(" Date: Sat, 23 Apr 2022 19:03:22 -0400 Subject: [PATCH 18/61] linter changes, why did that work on my local pre-commit?? 
--- dask_sql/context.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index 47670cf74..56bf73bd6 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -10,13 +10,7 @@ from dask.base import optimize from dask.distributed import Client -from dask_planner.rust import ( - DaskSchema, - DaskSQLContext, - DaskTable, - DaskTypeMap, - DFParsingException, -) +from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, DFParsingException try: import dask_cuda # noqa: F401 From 652205e500dff97226daf22d494aa54f7edf593c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 23 Apr 2022 19:10:44 -0400 Subject: [PATCH 19/61] linter changes, why did that work on my local pre-commit?? --- tests/integration/test_groupby.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py index b63435b48..e69baff8b 100644 --- a/tests/integration/test_groupby.py +++ b/tests/integration/test_groupby.py @@ -368,6 +368,7 @@ def test_stats_aggregation(c, timeseries_df): ) +@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ From 5127f87b525e55231ec1f2c08a68e07933098a16 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 23 Apr 2022 21:45:46 -0400 Subject: [PATCH 20/61] Convert final strs to SqlTypeName Enum --- dask_planner/src/sql/types.rs | 3 +- dask_sql/mappings.py | 53 ++++++++++++--------------- dask_sql/physical/rex/core/literal.py | 29 ++++++++------- tests/integration/test_groupby.py | 1 + tests/unit/test_mapping.py | 22 ++++++----- 5 files changed, 54 insertions(+), 54 deletions(-) diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index f439f82e9..70b2ed43d 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -50,7 +50,7 @@ impl DaskTypeMap { #[new] #[args(sql_type, py_kwargs = "**")] fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> Self { - println!("sql_type={:?} - py_kwargs={:?}", sql_type, py_kwargs); + // println!("sql_type={:?} - py_kwargs={:?}", sql_type, py_kwargs); let d_type: DataType = match sql_type { SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { @@ -158,6 +158,7 @@ pub enum SqlTypeName { FLOAT, GEOMETRY, INTEGER, + INTERVAL, INTERVAL_DAY, INTERVAL_DAY_HOUR, INTERVAL_DAY_MINUTE, diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index dede1f8e4..066928af1 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -45,18 +45,18 @@ # Default mapping between SQL types and python types # for values _SQL_TO_PYTHON_SCALARS = { - "DOUBLE": np.float64, - "FLOAT": np.float32, - "DECIMAL": np.float32, - "BIGINT": np.int64, - "INTEGER": np.int32, - "SMALLINT": np.int16, - "TINYINT": np.int8, - "BOOLEAN": np.bool8, - "VARCHAR": str, - "CHAR": str, - "NULL": type(None), - "SYMBOL": lambda x: x, # SYMBOL is a special type used for e.g. flags etc. We just keep it + "SqlTypeName.DOUBLE": np.float64, + "SqlTypeName.FLOAT": np.float32, + "SqlTypeName.DECIMAL": np.float32, + "SqlTypeName.BIGINT": np.int64, + "SqlTypeName.INTEGER": np.int32, + "SqlTypeName.SMALLINT": np.int16, + "SqlTypeName.TINYINT": np.int8, + "SqlTypeName.BOOLEAN": np.bool8, + "SqlTypeName.VARCHAR": str, + "SqlTypeName.CHAR": str, + "SqlTypeName.NULL": type(None), + "SqlTypeName.SYMBOL": lambda x: x, # SYMBOL is a special type used for e.g. flags etc. 
We just keep it } # Default mapping between SQL types and python types @@ -99,7 +99,7 @@ def python_to_sql_type(python_type) -> "DaskTypeMap": ) -def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: +def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: """Mapping between SQL and python values (of correct type).""" # In most of the cases, we turn the value first into a string. # That might not be the most efficient thing to do, @@ -110,14 +110,8 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: logger.debug( f"sql_to_python_value -> sql_type: {sql_type} literal_value: {literal_value}" ) - sql_type = sql_type.upper() - if ( - sql_type.startswith("CHAR(") - or sql_type.startswith("VARCHAR(") - or sql_type == "VARCHAR" - or sql_type == "CHAR" - ): + if sql_type == SqlTypeName.CHAR or sql_type == SqlTypeName.VARCHAR: # Some varchars contain an additional encoding # in the format _ENCODING'string' literal_value = str(literal_value) @@ -129,10 +123,10 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: return literal_value - elif sql_type.startswith("INTERVAL"): + elif sql_type == SqlTypeName.INTERVAL: # check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR try: - interval_type = sql_type.split()[1].lower() + interval_type = str(sql_type).split()[1].lower() if interval_type in {"year", "quarter", "month"}: # if sql_type is INTERVAL YEAR, Calcite will covert to months @@ -149,13 +143,13 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: # Issue: if sql_type is INTERVAL MICROSECOND, and value <= 1000, literal_value will be rounded to 0 return timedelta(milliseconds=float(str(literal_value))) - elif sql_type == "BOOLEAN": + elif sql_type == SqlTypeName.BOOLEAN: return bool(literal_value) elif ( - sql_type.startswith("TIMESTAMP(") - or sql_type.startswith("TIME(") - or sql_type == "DATE" + sql_type == SqlTypeName.TIMESTAMP + or sql_type == SqlTypeName.TIME + or sql_type == SqlTypeName.DATE ): if str(literal_value) == "None": # NULL time @@ -166,16 +160,16 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: dt = np.datetime64(literal_value.getTimeInMillis(), "ms") - if sql_type == "DATE": + if sql_type == SqlTypeName.DATE: return dt.astype(" bool: TODO: nullability is not checked so far. 
""" - print(f"similar_type: {lhs} - {rhs}") pdt = pd.api.types is_uint = pdt.is_unsigned_integer_dtype is_sint = pdt.is_signed_integer_dtype diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index b4eb886d1..6f1844de9 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -3,6 +3,7 @@ import dask.dataframe as dd +from dask_planner.rust import SqlTypeName from dask_sql.datacontainer import DataContainer from dask_sql.mappings import sql_to_python_value from dask_sql.physical.rex.base import BaseRexPlugin @@ -102,46 +103,46 @@ def convert( # Call the Rust function to get the actual value and convert the Rust # type name back to a SQL type if literal_type == "Boolean": - literal_type = "BOOLEAN" + literal_type = SqlTypeName.BOOLEAN literal_value = rex.getBoolValue() elif literal_type == "Float32": - literal_type = "FLOAT" + literal_type = SqlTypeName.FLOAT literal_value = rex.getFloat32Value() elif literal_type == "Float64": - literal_type = "DOUBLE" + literal_type = SqlTypeName.DOUBLE literal_value = rex.getFloat64Value() elif literal_type == "UInt8": - literal_type = "TINYINT" + literal_type = SqlTypeName.TINYINT literal_value = rex.getUInt8Value() elif literal_type == "UInt16": - literal_type = "SMALLINT" + literal_type = SqlTypeName.SMALLINT literal_value = rex.getUInt16Value() elif literal_type == "UInt32": - literal_type = "INTEGER" + literal_type = SqlTypeName.INTEGER literal_value = rex.getUInt32Value() elif literal_type == "UInt64": - literal_type = "BIGINT" + literal_type = SqlTypeName.BIGINT literal_value = rex.getUInt64Value() elif literal_type == "Int8": - literal_type = "TINYINT" + literal_type = SqlTypeName.TINYINT literal_value = rex.getInt8Value() elif literal_type == "Int16": - literal_type = "SMALLINT" + literal_type = SqlTypeName.SMALLINT literal_value = rex.getInt16Value() elif literal_type == "Int32": - literal_type = "INTEGER" + literal_type = SqlTypeName.INTEGER literal_value = rex.getInt32Value() elif literal_type == "Int64": - literal_type = "BIGINT" + literal_type = SqlTypeName.BIGINT literal_value = rex.getInt64Value() elif literal_type == "Utf8": - literal_type = "VARCHAR" + literal_type = SqlTypeName.VARCHAR literal_value = rex.getStringValue() elif literal_type == "Date32": - literal_type = "Date" + literal_type = SqlTypeName.DATE literal_value = rex.getDateValue() elif literal_type == "Date64": - literal_type = "Date" + literal_type = SqlTypeName.DATE literal_value = rex.getDateValue() else: raise RuntimeError("Failed to determine DataFusion Type in literal.py") diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py index e69baff8b..109074692 100644 --- a/tests/integration/test_groupby.py +++ b/tests/integration/test_groupby.py @@ -22,6 +22,7 @@ def timeseries_df(c): return None +@pytest.mark.skip(reason="WIP DataFusion") def test_group_by(c): return_df = c.sql( """ diff --git a/tests/unit/test_mapping.py b/tests/unit/test_mapping.py index 692b22843..dc62751cd 100644 --- a/tests/unit/test_mapping.py +++ b/tests/unit/test_mapping.py @@ -3,28 +3,32 @@ import numpy as np import pandas as pd +from dask_planner.rust import SqlTypeName from dask_sql.mappings import python_to_sql_type, similar_type, sql_to_python_value def test_python_to_sql(): - assert str(python_to_sql_type(np.dtype("int32"))) == "INTEGER" - assert str(python_to_sql_type(np.dtype(">M8[ns]"))) == "TIMESTAMP" + assert python_to_sql_type(np.dtype("int32")).getSqlType() == SqlTypeName.INTEGER 
+ assert python_to_sql_type(np.dtype(">M8[ns]")).getSqlType() == SqlTypeName.TIMESTAMP + thing = python_to_sql_type(pd.DatetimeTZDtype(unit="ns", tz="UTC")).getSqlType() assert ( - str(python_to_sql_type(pd.DatetimeTZDtype(unit="ns", tz="UTC"))) - == "timestamp[ms, tz=UTC]" + python_to_sql_type(pd.DatetimeTZDtype(unit="ns", tz="UTC")).getSqlType() + == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE ) def test_sql_to_python(): - assert sql_to_python_value("CHAR(5)", "test 123") == "test 123" - assert type(sql_to_python_value("BIGINT", 653)) == np.int64 - assert sql_to_python_value("BIGINT", 653) == 653 - assert sql_to_python_value("INTERVAL", 4) == timedelta(milliseconds=4) + assert sql_to_python_value(SqlTypeName.VARCHAR, "test 123") == "test 123" + assert type(sql_to_python_value(SqlTypeName.BIGINT, 653)) == np.int64 + assert sql_to_python_value(SqlTypeName.BIGINT, 653) == 653 + assert sql_to_python_value(SqlTypeName.INTERVAL, 4) == timedelta(microseconds=4000) def test_python_to_sql_to_python(): assert ( - type(sql_to_python_value(str(python_to_sql_type(np.dtype("int64"))), 54)) + type( + sql_to_python_value(python_to_sql_type(np.dtype("int64")).getSqlType(), 54) + ) == np.int64 ) From cf568dcd157a2048fbf3ed8d21d0b8d4f6707b3d Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 23 Apr 2022 22:23:35 -0400 Subject: [PATCH 21/61] removed a few print statements --- dask_sql/mappings.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 066928af1..0768732ba 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -189,12 +189,6 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: def sql_to_python_type(sql_type: "SqlTypeName") -> type: """Turn an SQL type into a dataframe dtype""" - # # Ex: Rust SqlTypeName Enum str value 'SqlTypeName.DOUBLE' - # if sql_type.find(".") != -1: - # sql_type = sql_type.split(".")[1] - - print(f"sql_type: {sql_type} type: {type(sql_type)}") - print(f"equal?: {SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE == sql_type}") if sql_type == SqlTypeName.VARCHAR or sql_type == SqlTypeName.CHAR: return pd.StringDtype() elif sql_type == SqlTypeName.TIME or sql_type == SqlTypeName.TIMESTAMP: From 4fb640ef3b8883590887979c0eddbe240a65f2d4 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sun, 24 Apr 2022 16:00:10 -0400 Subject: [PATCH 22/61] commit to share with colleague --- dask_planner/src/expression.rs | 11 +- dask_planner/src/sql.rs | 2 +- dask_planner/src/sql/logical/projection.rs | 9 + dask_planner/src/sql/types.rs | 1 - dask_sql/mappings.py | 6 +- dask_sql/physical/rel/base.py | 3 - dask_sql/physical/rel/logical/project.py | 2 - tests/integration/test_select.py | 353 +++++++++++---------- 8 files changed, 195 insertions(+), 192 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 4f589f1d1..a9f8edf3c 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -7,7 +7,7 @@ use std::convert::{From, Into}; use datafusion::error::DataFusionError; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{col, lit, BuiltinScalarFunction, Expr}; +use datafusion_expr::{lit, BuiltinScalarFunction, Expr}; use datafusion::scalar::ScalarValue; @@ -70,13 +70,9 @@ impl PyExpr { } } - fn _column_name(&self, mut plan: LogicalPlan) -> String { + fn _column_name(&self, plan: LogicalPlan) -> String { match &self.expr { Expr::Alias(expr, name) => { - println!("Alias encountered with name: {:?}", name); - // let reference: Expr = 
*expr.as_ref(); - // let plan: logical::PyLogicalPlan = reference.input().clone().into(); - // Only certain LogicalPlan variants are valid in this nested Alias scenario so we // extract the valid ones and error on the invalid ones match expr.as_ref() { @@ -116,9 +112,8 @@ impl PyExpr { } _ => name.clone(), } - } + }, _ => { - println!("Encountered a non Expr::Column instance"); name.clone() } } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 05650bbb1..2cbc6feef 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -158,7 +158,7 @@ impl DaskSQLContext { statement: statement::PyStatement, ) -> PyResult { let planner = SqlToRel::new(self); - + println!("Statement: {:?}", statement.statement); match planner.statement_to_plan(statement.statement) { Ok(k) => { println!("\nLogicalPlan: {:?}\n\n", k); diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 3d2eccdd8..d42bc58f6 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -39,9 +39,18 @@ impl PyProjection { _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input), } } + Expr::Cast { expr, data_type:_ } => { + let ex_type: Expr = *expr.clone(); + val = self.column_name(ex_type.into()).unwrap(); + println!("Setting col name to: {:?}", val); + }, _ => panic!("not supported: {:?}", expr), }, Expr::Column(col) => val = col.name.clone(), + Expr::Cast { expr, data_type:_ } => { + let ex_type: Expr = *expr; + val = self.column_name(ex_type.into()).unwrap() + }, _ => { panic!( "column_name is unimplemented for Expr variant: {:?}", diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 70b2ed43d..7e6047f7c 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -4,7 +4,6 @@ pub mod rel_data_type; pub mod rel_data_type_field; use pyo3::prelude::*; -use pyo3::types::PyAny; use pyo3::types::PyDict; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 0768732ba..9f8b12993 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -261,7 +261,7 @@ def cast_column_type( """ current_type = df[column_name].dtype - print(f"Column {column_name} has type {current_type}, expecting {expected_type}...") + # print(f"Column {column_name} has type {current_type}, expecting {expected_type}...") casted_column = cast_column_to_type(df[column_name], expected_type) @@ -290,4 +290,6 @@ def cast_column_to_type(col: dd.Series, expected_type: str): col = da.trunc(col.fillna(value=np.NaN)) print(f"Need to cast from {current_type} to {expected_type}") - return col.astype(expected_type) + col = col.astype(expected_type) + print(f"col type: {col.dtype}") + return col diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 996ec04e9..aae628f18 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -105,9 +105,6 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): expected_type = sql_to_python_type(field_type.getSqlType()) df_field_name = cc.get_backend_by_frontend_name(field_name) - print( - f"Before cast df_field_name: {df_field_name}, expected_type: {expected_type}" - ) df = cast_column_type(df, df_field_name, expected_type) return DataContainer(df, dc.column_container) diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 
c33054442..0226acc1e 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -30,8 +30,6 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container - print(f"Before Project: {df.head(10)}") - # Collect all (new) columns proj = rel.projection() named_projects = proj.getNamedProjects() diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index e4b357384..950002175 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -6,153 +6,153 @@ from tests.utils import assert_eq -def test_select(c, df): - result_df = c.sql("SELECT * FROM df") +# def test_select(c, df): +# result_df = c.sql("SELECT * FROM df") + +# assert_eq(result_df, df) - assert_eq(result_df, df) + +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_alias(c, df): +# result_df = c.sql("SELECT a as b, b as a FROM df") + +# expected_df = pd.DataFrame(index=df.index) +# expected_df["b"] = df.a +# expected_df["a"] = df.b + +# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_alias(c, df): - result_df = c.sql("SELECT a as b, b as a FROM df") +# def test_select_column(c, df): +# result_df = c.sql("SELECT a FROM df") - expected_df = pd.DataFrame(index=df.index) - expected_df["b"] = df.a - expected_df["a"] = df.b +# assert_eq(result_df, df[["a"]]) - assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) - -def test_select_column(c, df): - result_df = c.sql("SELECT a FROM df") - - assert_eq(result_df, df[["a"]]) - - -def test_select_different_types(c): - expected_df = pd.DataFrame( - { - "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), - "string": ["this is a test", "another test", "äölüć", ""], - "integer": [1, 2, -4, 5], - "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], - } - ) - c.create_table("df", expected_df) - result_df = c.sql( - """ - SELECT * - FROM df - """ - ) - - assert_eq(result_df, expected_df) - - -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_expr(c, df): - result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") - result_df = result_df - - expected_df = pd.DataFrame( - { - "a": df["a"] + 1, - "bla": df["b"], - '"df"."a" - 1': df["a"] - 1, - } - ) - assert_eq(result_df, expected_df) - - -@pytest.mark.skip( - reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" -) -def test_select_of_select(c, df): - result_df = c.sql( - """ - SELECT 2*c AS e, d - 1 AS f - FROM - ( - SELECT a - 1 AS c, 2*b AS d - FROM df - ) AS "inner" - """ - ) - - expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) - assert_eq(result_df, expected_df) - - -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_of_select_with_casing(c, df): - result_df = c.sql( - """ - SELECT AAA, aaa, aAa - FROM - ( - SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA - FROM df - ) AS "inner" - """ - ) - - expected_df = pd.DataFrame( - {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} - ) - - assert_eq(result_df, expected_df) - - -def test_wrong_input(c): - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") - - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") - - -def test_timezones(c, datetime_table): - result_df = c.sql( - """ - SELECT * FROM datetime_table - """ - ) - - print(f"Expected DF: \n{datetime_table.head(10)}\n") - print(f"\nResult DF: 
\n{result_df.head(10)}") - - assert_eq(result_df, datetime_table) - - -@pytest.mark.skip(reason="WIP DataFusion") -@pytest.mark.parametrize( - "input_table", - [ - "long_table", - pytest.param("gpu_long_table", marks=pytest.mark.gpu), - ], -) -@pytest.mark.parametrize( - "limit,offset", - [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], -) -def test_limit(c, input_table, limit, offset, request): - long_table = request.getfixturevalue(input_table) - - if not limit: - query = f"SELECT * FROM long_table OFFSET {offset}" - else: - query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" - - assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) - - -@pytest.mark.skip(reason="WIP DataFusion") +# def test_select_different_types(c): +# expected_df = pd.DataFrame( +# { +# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), +# "string": ["this is a test", "another test", "äölüć", ""], +# "integer": [1, 2, -4, 5], +# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], +# } +# ) +# c.create_table("df", expected_df) +# result_df = c.sql( +# """ +# SELECT * +# FROM df +# """ +# ) + +# assert_eq(result_df, expected_df) + + +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_expr(c, df): +# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") +# result_df = result_df + +# expected_df = pd.DataFrame( +# { +# "a": df["a"] + 1, +# "bla": df["b"], +# '"df"."a" - 1': df["a"] - 1, +# } +# ) +# assert_eq(result_df, expected_df) + + +# @pytest.mark.skip( +# reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" +# ) +# def test_select_of_select(c, df): +# result_df = c.sql( +# """ +# SELECT 2*c AS e, d - 1 AS f +# FROM +# ( +# SELECT a - 1 AS c, 2*b AS d +# FROM df +# ) AS "inner" +# """ +# ) + +# expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) +# assert_eq(result_df, expected_df) + + +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_of_select_with_casing(c, df): +# result_df = c.sql( +# """ +# SELECT AAA, aaa, aAa +# FROM +# ( +# SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA +# FROM df +# ) AS "inner" +# """ +# ) + +# expected_df = pd.DataFrame( +# {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} +# ) + +# assert_eq(result_df, expected_df) + + +# def test_wrong_input(c): +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") + +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") + + +# def test_timezones(c, datetime_table): +# result_df = c.sql( +# """ +# SELECT * FROM datetime_table +# """ +# ) + +# print(f"Expected DF: \n{datetime_table.head(10)}\n") +# print(f"\nResult DF: \n{result_df.head(10)}") + +# assert_eq(result_df, datetime_table) + + +# @pytest.mark.skip(reason="WIP DataFusion") +# @pytest.mark.parametrize( +# "input_table", +# [ +# "long_table", +# pytest.param("gpu_long_table", marks=pytest.mark.gpu), +# ], +# ) +# @pytest.mark.parametrize( +# "limit,offset", +# [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], +# ) +# def test_limit(c, input_table, limit, offset, request): +# long_table = request.getfixturevalue(input_table) + +# if not limit: +# query = f"SELECT * FROM long_table OFFSET {offset}" +# else: +# query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + +# assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) + + +# @pytest.mark.skip(reason="WIP DataFusion") 
@pytest.mark.parametrize( "input_table", [ "datetime_table", - pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), + # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), ], ) def test_date_casting(c, input_table, request): @@ -178,45 +178,48 @@ def test_date_casting(c, input_table, request): expected_df["utc_timezone"].astype(" Date: Sun, 24 Apr 2022 20:19:49 -0400 Subject: [PATCH 23/61] updates --- dask_planner/src/expression.rs | 4 +--- dask_planner/src/sql/logical/projection.rs | 8 ++++---- tests/integration/test_select.py | 7 +++---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index a9f8edf3c..c682c3e1d 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -112,10 +112,8 @@ impl PyExpr { } _ => name.clone(), } - }, - _ => { - name.clone() } + _ => name.clone(), } } Expr::Column(column) => column.name.clone(), diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index d42bc58f6..d518a14b2 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -39,18 +39,18 @@ impl PyProjection { _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input), } } - Expr::Cast { expr, data_type:_ } => { + Expr::Cast { expr, data_type: _ } => { let ex_type: Expr = *expr.clone(); val = self.column_name(ex_type.into()).unwrap(); println!("Setting col name to: {:?}", val); - }, + } _ => panic!("not supported: {:?}", expr), }, Expr::Column(col) => val = col.name.clone(), - Expr::Cast { expr, data_type:_ } => { + Expr::Cast { expr, data_type: _ } => { let ex_type: Expr = *expr; val = self.column_name(ex_type.into()).unwrap() - }, + } _ => { panic!( "column_name is unimplemented for Expr variant: {:?}", diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 950002175..b794e0c9f 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -1,11 +1,10 @@ -import numpy as np -import pandas as pd +# import numpy as np +# import pandas as pd import pytest -from dask_sql.utils import ParsingException +# from dask_sql.utils import ParsingException from tests.utils import assert_eq - # def test_select(c, df): # result_df = c.sql("SELECT * FROM df") From f5e24fe75c966325a55aa418474a147d05e8dcc9 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 08:47:26 -0400 Subject: [PATCH 24/61] checkpoint --- dask_planner/src/expression.rs | 7 ++ dask_planner/src/sql/logical/projection.rs | 7 ++ tests/integration/test_select.py | 90 +++++++++++----------- 3 files changed, 59 insertions(+), 45 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index c682c3e1d..8bb3b68c5 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -75,6 +75,7 @@ impl PyExpr { Expr::Alias(expr, name) => { // Only certain LogicalPlan variants are valid in this nested Alias scenario so we // extract the valid ones and error on the invalid ones + println!("Alias Expr: {:?} - Alias Name: {:?}", expr, name); match expr.as_ref() { Expr::Column(col) => { // First we must iterate the current node before getting its input @@ -113,6 +114,12 @@ impl PyExpr { _ => name.clone(), } } + // Expr::Case { expr, when_then_expr, else_expr } => { + // println!("expr: {:?}", expr); + // println!("when_then_expr: {:?}", when_then_expr); + // println!("else_expr: 
{:?}", else_expr); + // panic!("Case WHEN BABY!!!"); + // }, _ => name.clone(), } } diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index d518a14b2..9caee8e87 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -44,6 +44,13 @@ impl PyProjection { val = self.column_name(ex_type.into()).unwrap(); println!("Setting col name to: {:?}", val); } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + println!("Case WHEN BABY!!!"); + } _ => panic!("not supported: {:?}", expr), }, Expr::Column(col) => val = col.name.clone(), diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index b794e0c9f..d67d4b253 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -1,5 +1,5 @@ # import numpy as np -# import pandas as pd +import pandas as pd import pytest # from dask_sql.utils import ParsingException @@ -146,41 +146,41 @@ # assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) -# @pytest.mark.skip(reason="WIP DataFusion") -@pytest.mark.parametrize( - "input_table", - [ - "datetime_table", - # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), - ], -) -def test_date_casting(c, input_table, request): - datetime_table = request.getfixturevalue(input_table) - result_df = c.sql( - f""" - SELECT - CAST(timezone AS DATE) AS timezone, - CAST(no_timezone AS DATE) AS no_timezone, - CAST(utc_timezone AS DATE) AS utc_timezone - FROM {input_table} - """ - ) +# # @pytest.mark.skip(reason="WIP DataFusion") +# @pytest.mark.parametrize( +# "input_table", +# [ +# "datetime_table", +# # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), +# ], +# ) +# def test_date_casting(c, input_table, request): +# datetime_table = request.getfixturevalue(input_table) +# result_df = c.sql( +# f""" +# SELECT +# CAST(timezone AS DATE) AS timezone, +# CAST(no_timezone AS DATE) AS no_timezone, +# CAST(utc_timezone AS DATE) AS utc_timezone +# FROM {input_table} +# """ +# ) - expected_df = datetime_table - expected_df["timezone"] = ( - expected_df["timezone"].astype(" Date: Mon, 25 Apr 2022 09:02:38 -0400 Subject: [PATCH 25/61] Temporarily disable conda run_test.py script since it uses features not yet implemented --- continuous_integration/recipe/run_test.py | 30 ++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/continuous_integration/recipe/run_test.py b/continuous_integration/recipe/run_test.py index 0ca97261b..01616d1db 100644 --- a/continuous_integration/recipe/run_test.py +++ b/continuous_integration/recipe/run_test.py @@ -13,19 +13,21 @@ df = pd.DataFrame({"name": ["Alice", "Bob", "Chris"] * 100, "x": list(range(300))}) ddf = dd.from_pandas(df, npartitions=10) -c.create_table("my_data", ddf) -got = c.sql( - """ - SELECT - my_data.name, - SUM(my_data.x) AS "S" - FROM - my_data - GROUP BY - my_data.name -""" -) -expect = pd.DataFrame({"name": ["Alice", "Bob", "Chris"], "S": [14850, 14950, 15050]}) +# This needs to be temprarily disabled since this query requires features that are not yet implemented +# c.create_table("my_data", ddf) + +# got = c.sql( +# """ +# SELECT +# my_data.name, +# SUM(my_data.x) AS "S" +# FROM +# my_data +# GROUP BY +# my_data.name +# """ +# ) +# expect = pd.DataFrame({"name": ["Alice", "Bob", "Chris"], "S": [14850, 14950, 15050]}) -dd.assert_eq(got, expect) +# dd.assert_eq(got, expect) From 46dfb0a423470f40a6b0e8fd0bf423448fccc0ce Mon Sep 17 00:00:00 2001 
From: Jeremy Dyer Date: Mon, 25 Apr 2022 09:18:28 -0400 Subject: [PATCH 26/61] formatting after upstream merge --- dask_sql/mappings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 9f8b12993..1d2c56ec3 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -261,7 +261,9 @@ def cast_column_type( """ current_type = df[column_name].dtype - # print(f"Column {column_name} has type {current_type}, expecting {expected_type}...") + logger.debug( + f"Column {column_name} has type {current_type}, expecting {expected_type}..." + ) casted_column = cast_column_to_type(df[column_name], expected_type) From fa71674003b9faa4e72bdf46459525411ee442a0 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 09:31:33 -0400 Subject: [PATCH 27/61] expose fromString method for SqlTypeName to use Enums instead of strings for type checking --- dask_planner/src/sql/types.rs | 53 ++++++++++++++++-------------- dask_sql/input_utils/hive.py | 4 ++- dask_sql/physical/rex/core/call.py | 7 ++-- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 7e6047f7c..76c4699b9 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -248,34 +248,37 @@ impl SqlTypeName { _ => todo!(), } } +} +#[pymethods] +impl SqlTypeName { + #[pyo3(name = "fromString")] + #[staticmethod] pub fn from_string(input_type: &str) -> Self { match input_type { - "SqlTypeName.NULL" => SqlTypeName::NULL, - "SqlTypeName.BOOLEAN" => SqlTypeName::BOOLEAN, - "SqlTypeName.TINYINT" => SqlTypeName::TINYINT, - "SqlTypeName.SMALLINT" => SqlTypeName::SMALLINT, - "SqlTypeName.INTEGER" => SqlTypeName::INTEGER, - "SqlTypeName.BIGINT" => SqlTypeName::BIGINT, - "SqlTypeName.REAL" => SqlTypeName::REAL, - "SqlTypeName.FLOAT" => SqlTypeName::FLOAT, - "SqlTypeName.DOUBLE" => SqlTypeName::DOUBLE, - "SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE" => { - SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE - } - "SqlTypeName.TIMESTAMP" => SqlTypeName::TIMESTAMP, - "SqlTypeName.DATE" => SqlTypeName::DATE, - "SqlTypeName.INTERVAL_DAY" => SqlTypeName::INTERVAL_DAY, - "SqlTypeName.INTERVAL_YEAR_MONTH" => SqlTypeName::INTERVAL_YEAR_MONTH, - "SqlTypeName.INTERVAL_MONTH" => SqlTypeName::INTERVAL_MONTH, - "SqlTypeName.BINARY" => SqlTypeName::BINARY, - "SqlTypeName.VARBINARY" => SqlTypeName::VARBINARY, - "SqlTypeName.CHAR" => SqlTypeName::CHAR, - "SqlTypeName.VARCHAR" => SqlTypeName::VARCHAR, - "SqlTypeName.STRUCTURED" => SqlTypeName::STRUCTURED, - "SqlTypeName.DECIMAL" => SqlTypeName::DECIMAL, - "SqlTypeName.MAP" => SqlTypeName::MAP, - _ => todo!(), + "NULL" => SqlTypeName::NULL, + "BOOLEAN" => SqlTypeName::BOOLEAN, + "TINYINT" => SqlTypeName::TINYINT, + "SMALLINT" => SqlTypeName::SMALLINT, + "INTEGER" => SqlTypeName::INTEGER, + "BIGINT" => SqlTypeName::BIGINT, + "REAL" => SqlTypeName::REAL, + "FLOAT" => SqlTypeName::FLOAT, + "DOUBLE" => SqlTypeName::DOUBLE, + "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE, + "TIMESTAMP" => SqlTypeName::TIMESTAMP, + "DATE" => SqlTypeName::DATE, + "INTERVAL_DAY" => SqlTypeName::INTERVAL_DAY, + "INTERVAL_YEAR_MONTH" => SqlTypeName::INTERVAL_YEAR_MONTH, + "INTERVAL_MONTH" => SqlTypeName::INTERVAL_MONTH, + "BINARY" => SqlTypeName::BINARY, + "VARBINARY" => SqlTypeName::VARBINARY, + "CHAR" => SqlTypeName::CHAR, + "VARCHAR" => SqlTypeName::VARCHAR, + "STRUCTURED" => SqlTypeName::STRUCTURED, + "DECIMAL" => SqlTypeName::DECIMAL, + "MAP" => 
SqlTypeName::MAP, + _ => unimplemented!(), } } } diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py index 4e1bdde62..a50c167fd 100644 --- a/dask_sql/input_utils/hive.py +++ b/dask_sql/input_utils/hive.py @@ -6,6 +6,8 @@ import dask.dataframe as dd +from dask_planner.rust import SqlTypeName + try: from pyhive import hive except ImportError: # pragma: no cover @@ -65,7 +67,7 @@ def to_dc( # Convert column information column_information = { - col: sql_to_python_type(col_type.upper()) + col: sql_to_python_type(SqlTypeName.fromString(col_type.upper())) for col, col_type in column_information.items() } diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 4ef1d64bf..68c941c30 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -14,6 +14,7 @@ from dask.highlevelgraph import HighLevelGraph from dask.utils import random_state_data +from dask_planner.rust import SqlTypeName from dask_sql.datacontainer import DataContainer from dask_sql.mappings import cast_column_to_type, sql_to_python_type from dask_sql.physical.rex import RexConverter @@ -140,7 +141,7 @@ def div(self, lhs, rhs, rex=None): result = lhs / rhs output_type = str(rex.getType()) - output_type = sql_to_python_type(output_type.upper()) + output_type = sql_to_python_type(SqlTypeName.fromString(output_type.upper())) is_float = pd.api.types.is_float_dtype(output_type) if not is_float: @@ -224,7 +225,9 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: return operand output_type = str(rex.getType()) - python_type = sql_to_python_type(output_type.upper()) + python_type = sql_to_python_type( + output_type=sql_to_python_type(output_type.upper()) + ) return_column = cast_column_to_type(operand, python_type) From f6e86ca154da4c6f4191f96f3b63a4fd30ff5843 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 10:05:50 -0400 Subject: [PATCH 28/61] expanded SqlTypeName from_string() support --- dask_planner/src/sql/types.rs | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 76c4699b9..26bda2377 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -256,29 +256,55 @@ impl SqlTypeName { #[staticmethod] pub fn from_string(input_type: &str) -> Self { match input_type { + "ANY" => SqlTypeName::ANY, + "ARRAY" => SqlTypeName::ARRAY, "NULL" => SqlTypeName::NULL, "BOOLEAN" => SqlTypeName::BOOLEAN, + "COLUMN_LIST" => SqlTypeName::COLUMN_LIST, + "DISTINCT" => SqlTypeName::DISTINCT, + "CURSOR" => SqlTypeName::CURSOR, "TINYINT" => SqlTypeName::TINYINT, "SMALLINT" => SqlTypeName::SMALLINT, "INTEGER" => SqlTypeName::INTEGER, "BIGINT" => SqlTypeName::BIGINT, "REAL" => SqlTypeName::REAL, "FLOAT" => SqlTypeName::FLOAT, + "GEOMETRY" => SqlTypeName::GEOMETRY, "DOUBLE" => SqlTypeName::DOUBLE, - "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE, + "TIME" => SqlTypeName::TIME, + "TIME_WITH_LOCAL_TIME_ZONE" => SqlTypeName::TIME_WITH_LOCAL_TIME_ZONE, "TIMESTAMP" => SqlTypeName::TIMESTAMP, + "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE, "DATE" => SqlTypeName::DATE, + "INTERVAL" => SqlTypeName::INTERVAL, "INTERVAL_DAY" => SqlTypeName::INTERVAL_DAY, - "INTERVAL_YEAR_MONTH" => SqlTypeName::INTERVAL_YEAR_MONTH, + "INTERVAL_DAY_HOUR" => SqlTypeName::INTERVAL_DAY_HOUR, + "INTERVAL_DAY_MINUTE" => SqlTypeName::INTERVAL_DAY_MINUTE, + "INTERVAL_DAY_SECOND" => 
SqlTypeName::INTERVAL_DAY_SECOND, + "INTERVAL_HOUR" => SqlTypeName::INTERVAL_HOUR, + "INTERVAL_HOUR_MINUTE" => SqlTypeName::INTERVAL_HOUR_MINUTE, + "INTERVAL_HOUR_SECOND" => SqlTypeName::INTERVAL_HOUR_SECOND, + "INTERVAL_MINUTE" => SqlTypeName::INTERVAL_MINUTE, + "INTERVAL_MINUTE_SECOND" => SqlTypeName::INTERVAL_MINUTE_SECOND, "INTERVAL_MONTH" => SqlTypeName::INTERVAL_MONTH, + "INTERVAL_SECOND" => SqlTypeName::INTERVAL_SECOND, + "INTERVAL_YEAR" => SqlTypeName::INTERVAL_YEAR, + "INTERVAL_YEAR_MONTH" => SqlTypeName::INTERVAL_YEAR_MONTH, + "MAP" => SqlTypeName::MAP, + "MULTISET" => SqlTypeName::MULTISET, + "OTHER" => SqlTypeName::OTHER, + "ROW" => SqlTypeName::ROW, + "SARG" => SqlTypeName::SARG, "BINARY" => SqlTypeName::BINARY, "VARBINARY" => SqlTypeName::VARBINARY, "CHAR" => SqlTypeName::CHAR, "VARCHAR" => SqlTypeName::VARCHAR, "STRUCTURED" => SqlTypeName::STRUCTURED, + "SYMBOL" => SqlTypeName::SYMBOL, "DECIMAL" => SqlTypeName::DECIMAL, - "MAP" => SqlTypeName::MAP, - _ => unimplemented!(), + "DYNAMIC_STAT" => SqlTypeName::DYNAMIC_STAR, + "UNKNOWN" => SqlTypeName::UNKNOWN, + _ => unimplemented!("SqlTypeName::from_string() for str type: {:?}", input_type), } } } From 3d1a5ad934d112327395d085369307605f4590de Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 10:23:36 -0400 Subject: [PATCH 29/61] accept INT as INTEGER --- dask_planner/src/sql/types.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 26bda2377..618786d88 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -265,6 +265,7 @@ impl SqlTypeName { "CURSOR" => SqlTypeName::CURSOR, "TINYINT" => SqlTypeName::TINYINT, "SMALLINT" => SqlTypeName::SMALLINT, + "INT" => SqlTypeName::INTEGER, "INTEGER" => SqlTypeName::INTEGER, "BIGINT" => SqlTypeName::BIGINT, "REAL" => SqlTypeName::REAL, From 384e446e692c892478929d42987b51e2bf4b3fa7 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 14:20:01 -0400 Subject: [PATCH 30/61] tests update --- dask_planner/src/expression.rs | 64 ++++++++++++---------- dask_planner/src/sql/logical/projection.rs | 11 +--- dask_sql/physical/rel/logical/project.py | 8 ++- dask_sql/physical/rex/core/input_ref.py | 6 +- tests/integration/test_select.py | 11 ++-- 5 files changed, 54 insertions(+), 46 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 8bb3b68c5..a81ce7641 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -96,31 +96,25 @@ impl PyExpr { "AGGREGATE COLUMN IS {}", col.name ); - col.name.clone() + col.name.clone().to_ascii_uppercase() } - _ => name.clone(), + _ => name.clone().to_ascii_uppercase(), } } - _ => name.clone(), + _ => name.clone().to_ascii_uppercase(), } } _ => { println!("Encountered a non-Aggregate type"); - name.clone() + name.clone().to_ascii_uppercase() } } } - _ => name.clone(), + _ => name.clone().to_ascii_uppercase(), } } - // Expr::Case { expr, when_then_expr, else_expr } => { - // println!("expr: {:?}", expr); - // println!("when_then_expr: {:?}", when_then_expr); - // println!("else_expr: {:?}", else_expr); - // panic!("Case WHEN BABY!!!"); - // }, - _ => name.clone(), + _ => name.clone().to_ascii_uppercase(), } } Expr::Column(column) => column.name.clone(), @@ -173,25 +167,37 @@ impl PyExpr { } } + // /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema + // #[pyo3(name = "getIndex")] + // pub fn index(&self, input_plan: LogicalPlan) -> PyResult { + // // let 
input: &Option> = &self.input_plan; + // match input_plan { + // Some(plan) => { + // let name: Result = self.expr.name(plan.schema()); + // match name { + // Ok(fq_name) => Ok(plan + // .schema() + // .index_of_column(&Column::from_qualified_name(&fq_name)) + // .unwrap()), + // Err(e) => panic!("{:?}", e), + // } + // } + // None => { + // panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") + // } + // } + // } + /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema #[pyo3(name = "getIndex")] - pub fn index(&self) -> PyResult { - let input: &Option> = &self.input_plan; - match input { - Some(plan) => { - let name: Result = self.expr.name(plan.schema()); - match name { - Ok(fq_name) => Ok(plan - .schema() - .index_of_column(&Column::from_qualified_name(&fq_name)) - .unwrap()), - Err(e) => panic!("{:?}", e), - } - } - None => { - panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") - } - } + pub fn index(&self, input_plan: logical::PyLogicalPlan) -> PyResult { + let fq_name: Result = + self.expr.name(input_plan.original_plan.schema()); + Ok(input_plan + .original_plan + .schema() + .index_of_column(&Column::from_qualified_name(&fq_name.unwrap())) + .unwrap()) } /// Examine the current/"self" PyExpr and return its "type" diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 9caee8e87..0380ee9bf 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -17,7 +17,7 @@ impl PyProjection { fn column_name(&mut self, expr: PyExpr) -> PyResult { let mut val: String = String::from("OK"); match expr.expr { - Expr::Alias(expr, _alias) => match expr.as_ref() { + Expr::Alias(expr, name) => match expr.as_ref() { Expr::Column(col) => { let index = self.projection.input.schema().index_of_column(col).unwrap(); match self.projection.input.as_ref() { @@ -44,14 +44,7 @@ impl PyProjection { val = self.column_name(ex_type.into()).unwrap(); println!("Setting col name to: {:?}", val); } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - println!("Case WHEN BABY!!!"); - } - _ => panic!("not supported: {:?}", expr), + _ => val = name.clone().to_ascii_uppercase(), }, Expr::Column(col) => val = col.name.clone(), Expr::Cast { expr, data_type: _ } => { diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 0226acc1e..9cb1223a1 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -38,6 +38,8 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai new_columns = {} new_mappings = {} + print(f"Named Projects: {named_projects} df columns: {df.columns}") + # Collect all (new) columns this Projection will limit to for key, expr in named_projects: @@ -54,13 +56,15 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # shortcut: if we have a column already, there is no need to re-assign it again # this is only the case if the expr is a RexInputRef if expr.getRexType() == RexType.Reference: - index = expr.getIndex() + print(f"Reference for Expr: {expr}") + index = expr.getIndex(rel) backend_column_name = cc.get_backend_by_frontend_index(index) logger.debug( f"Not re-adding the same column {key} (but just referencing it)" ) new_mappings[key] = backend_column_name else: + print(f"Other for Expr: {expr}") random_name = new_temporary_column(df) new_columns[random_name] = 
RexConverter.convert( expr, dc, context=context @@ -68,6 +72,8 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai logger.debug(f"Adding a new column {key} out of {expr}") new_mappings[key] = random_name + print(f"Projecting columns: {column_names}") + # Actually add the new columns if new_columns: df = df.assign(**new_columns) diff --git a/dask_sql/physical/rex/core/input_ref.py b/dask_sql/physical/rex/core/input_ref.py index 152b24caf..66a45654d 100644 --- a/dask_sql/physical/rex/core/input_ref.py +++ b/dask_sql/physical/rex/core/input_ref.py @@ -27,7 +27,9 @@ def convert( context: "dask_sql.Context", ) -> dd.Series: df = dc.df + cc = dc.column_container # The column is references by index - column_name = str(expr.column_name(rel)) - return df[column_name] + index = expr.getIndex(rel) + backend_column_name = cc.get_backend_by_frontend_index(index) + return df[backend_column_name] diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index d67d4b253..694451066 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -1,10 +1,12 @@ -# import numpy as np +import numpy as np import pandas as pd -import pytest # from dask_sql.utils import ParsingException from tests.utils import assert_eq +# import pytest + + # def test_select(c, df): # result_df = c.sql("SELECT * FROM df") @@ -207,7 +209,6 @@ # assert_eq(result_df, expected_df) -# @pytest.mark.skip(reason="WIP DataFusion") def test_multi_case_when(c): df = pd.DataFrame({"a": [1, 6, 7, 8, 9]}) c.create_table("df", df) @@ -215,10 +216,10 @@ def test_multi_case_when(c): actual_df = c.sql( """ SELECT - CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS C + CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS "C" FROM df """ ) - expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int32) + expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int64) assert_eq(actual_df, expected_df) From 199b9d2ef5bd1307ba91966d40b118c88141129f Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 25 Apr 2022 16:41:45 -0400 Subject: [PATCH 31/61] checkpoint --- dask_planner/src/expression.rs | 29 +++++++++++++++++++++++- dask_sql/physical/rel/logical/project.py | 2 -- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index a81ce7641..835937853 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -148,6 +148,33 @@ impl PyExpr { _ => panic!("Nothing found!!!"), } } + + fn _rex_type(&self, expr: Expr) -> RexType { + match &expr { + Expr::Alias(..) => self._rex_type(expr.input), + Expr::Column(..) => RexType::Reference, + Expr::ScalarVariable(..) => RexType::Literal, + Expr::Literal(..) => RexType::Literal, + Expr::BinaryExpr { .. } => RexType::Call, + Expr::Not(..) => RexType::Call, + Expr::IsNotNull(..) => RexType::Call, + Expr::Negative(..) => RexType::Call, + Expr::GetIndexedField { .. } => RexType::Reference, + Expr::IsNull(..) => RexType::Call, + Expr::Between { .. } => RexType::Call, + Expr::Case { .. } => RexType::Call, + Expr::Cast { .. } => RexType::Call, + Expr::TryCast { .. } => RexType::Call, + Expr::Sort { .. } => RexType::Call, + Expr::ScalarFunction { .. } => RexType::Call, + Expr::AggregateFunction { .. } => RexType::Call, + Expr::WindowFunction { .. } => RexType::Call, + Expr::AggregateUDF { .. } => RexType::Call, + Expr::InList { .. 
} => RexType::Call, + Expr::Wildcard => RexType::Call, + _ => RexType::Other, + } + } } #[pymethods] @@ -236,7 +263,7 @@ impl PyExpr { #[pyo3(name = "getRexType")] pub fn rex_type(&self) -> RexType { match &self.expr { - Expr::Alias(..) => RexType::Reference, + Expr::Alias(..) => self.rex_type(), Expr::Column(..) => RexType::Reference, Expr::ScalarVariable(..) => RexType::Literal, Expr::Literal(..) => RexType::Literal, diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 9cb1223a1..ecaaa0ccd 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -38,8 +38,6 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai new_columns = {} new_mappings = {} - print(f"Named Projects: {named_projects} df columns: {df.columns}") - # Collect all (new) columns this Projection will limit to for key, expr in named_projects: From c9dffaec91f38929d7c26330d73f80c62d19885c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Wed, 27 Apr 2022 09:04:43 -0400 Subject: [PATCH 32/61] checkpoint --- dask_planner/Cargo.toml | 4 ++-- dask_planner/src/expression.rs | 24 ++++-------------------- dask_planner/src/sql/logical.rs | 1 + dask_sql/physical/rel/logical/project.py | 3 ++- dask_sql/physical/rex/convert.py | 2 +- dask_sql/physical/rex/core/call.py | 3 +++ dask_sql/physical/rex/core/input_ref.py | 4 ++-- 7 files changed, 15 insertions(+), 26 deletions(-) diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index bc7e3138a..c797c354b 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -12,8 +12,8 @@ rust-version = "1.59" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } -datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" } +datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "6ae7d9599813b3aaf72b22a0d18f4d27bef0f730" } +datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "6ae7d9599813b3aaf72b22a0d18f4d27bef0f730" } uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } sqlparser = "0.14.0" diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 835937853..ef0500d04 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -194,26 +194,10 @@ impl PyExpr { } } - // /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema - // #[pyo3(name = "getIndex")] - // pub fn index(&self, input_plan: LogicalPlan) -> PyResult { - // // let input: &Option> = &self.input_plan; - // match input_plan { - // Some(plan) => { - // let name: Result = self.expr.name(plan.schema()); - // match name { - // Ok(fq_name) => Ok(plan - // .schema() - // .index_of_column(&Column::from_qualified_name(&fq_name)) - // .unwrap()), - // Err(e) => panic!("{:?}", e), - // } - // } - // None => { - // panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") - // } - // } - // } + #[pyo3(name = "toString")] + pub fn to_string(&self) -> PyResult { + Ok(format!("{}", &self.expr)) + } /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema #[pyo3(name = 
"getIndex")] diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 0c9ca27f5..2359afdf2 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -122,6 +122,7 @@ impl PyLogicalPlan { LogicalPlan::Explain(_explain) => "Explain", LogicalPlan::Analyze(_analyze) => "Analyze", LogicalPlan::Extension(_extension) => "Extension", + LogicalPlan::Subquery(_sub_query) => "Subquery", LogicalPlan::SubqueryAlias(_sqalias) => "SubqueryAlias", LogicalPlan::CreateCatalogSchema(_create) => "CreateCatalogSchema", LogicalPlan::CreateCatalog(_create_catalog) => "CreateCatalog", diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index ecaaa0ccd..ed6b3f0ce 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -43,12 +43,13 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai key = str(key) column_names.append(key) - random_name = new_temporary_column(df) new_columns[random_name] = RexConverter.convert( rel, expr, dc, context=context ) + breakpoint() + new_mappings[key] = random_name # shortcut: if we have a column already, there is no need to re-assign it again diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index 1123e8359..f83c9bd40 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -55,7 +55,7 @@ def convert( context: "dask_sql.Context", ) -> Union[dd.DataFrame, Any]: """ - Convert the given rel (java instance) + Convert the given Expression into a python expression (a dask dataframe) using the stored plugins and the dictionary of registered dask tables. diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 68c941c30..4e6724b06 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -224,6 +224,8 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: if not is_frame(operand): # pragma: no cover return operand + breakpoint() + output_type = str(rex.getType()) python_type = sql_to_python_type( output_type=sql_to_python_type(output_type.upper()) @@ -891,6 +893,7 @@ def convert( logger.debug(f"Operator Name: {operator_name}") try: + breakpoint() operation = self.OPERATION_MAPPING[operator_name] except KeyError: try: diff --git a/dask_sql/physical/rex/core/input_ref.py b/dask_sql/physical/rex/core/input_ref.py index 66a45654d..4272c832e 100644 --- a/dask_sql/physical/rex/core/input_ref.py +++ b/dask_sql/physical/rex/core/input_ref.py @@ -22,7 +22,7 @@ class RexInputRefPlugin(BaseRexPlugin): def convert( self, rel: "LogicalPlan", - expr: "Expression", + rex: "Expression", dc: DataContainer, context: "dask_sql.Context", ) -> dd.Series: @@ -30,6 +30,6 @@ def convert( cc = dc.column_container # The column is references by index - index = expr.getIndex(rel) + index = rex.getIndex() backend_column_name = cc.get_backend_by_frontend_index(index) return df[backend_column_name] From c9aad43fe1aafbb7a0c71a55277e5cb37d8e6f50 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Wed, 27 Apr 2022 22:56:16 -0400 Subject: [PATCH 33/61] Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls --- dask_planner/src/expression.rs | 57 +++++--- dask_planner/src/sql/logical/projection.rs | 50 ++++--- dask_sql/physical/rel/logical/project.py | 6 +- dask_sql/physical/rex/convert.py | 22 ++- dask_sql/physical/rex/core/call.py | 10 +- tests/integration/test_select.py | 160 
+++++++++++---------- 6 files changed, 173 insertions(+), 132 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index ef0500d04..b3449e623 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -31,15 +31,6 @@ impl From for Expr { } } -impl From for PyExpr { - fn from(expr: Expr) -> PyExpr { - PyExpr { - input_plan: None, - expr: expr, - } - } -} - #[pyclass(name = "ScalarValue", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyScalarValue { @@ -181,7 +172,7 @@ impl PyExpr { impl PyExpr { #[staticmethod] pub fn literal(value: PyScalarValue) -> PyExpr { - lit(value.scalar_value).into() + PyExpr::from(lit(value.scalar_value), None) } /// If this Expression instances references an existing @@ -201,14 +192,25 @@ impl PyExpr { /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema #[pyo3(name = "getIndex")] - pub fn index(&self, input_plan: logical::PyLogicalPlan) -> PyResult { - let fq_name: Result = - self.expr.name(input_plan.original_plan.schema()); - Ok(input_plan - .original_plan - .schema() - .index_of_column(&Column::from_qualified_name(&fq_name.unwrap())) - .unwrap()) + pub fn index(&self) -> PyResult { + println!("&self: {:?}", &self); + println!("&self.input_plan: {:?}", self.input_plan); + let input: &Option> = &self.input_plan; + match input { + Some(plan) => { + let name: Result = self.expr.name(plan.schema()); + match name { + Ok(fq_name) => Ok(plan + .schema() + .index_of_column(&Column::from_qualified_name(&fq_name)) + .unwrap()), + Err(e) => panic!("{:?}", e), + } + } + None => { + panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") + } + } } /// Examine the current/"self" PyExpr and return its "type" @@ -247,7 +249,14 @@ impl PyExpr { #[pyo3(name = "getRexType")] pub fn rex_type(&self) -> RexType { match &self.expr { +<<<<<<< HEAD Expr::Alias(..) => self.rex_type(), +======= + Expr::Alias(expr, name) => { + println!("expr: {:?}", *expr); + RexType::Reference + } +>>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls Expr::Column(..) => RexType::Reference, Expr::ScalarVariable(..) => RexType::Literal, Expr::Literal(..) 
=> RexType::Literal, @@ -284,22 +293,26 @@ impl PyExpr { Expr::BinaryExpr { left, op: _, right } => { let mut operands: Vec = Vec::new(); let left_desc: Expr = *left.clone(); - operands.push(left_desc.into()); + let py_left: PyExpr = PyExpr::from(left_desc, self.input_plan.clone()); + operands.push(py_left); let right_desc: Expr = *right.clone(); - operands.push(right_desc.into()); + let py_right: PyExpr = PyExpr::from(right_desc, self.input_plan.clone()); + operands.push(py_right); Ok(operands) } Expr::ScalarFunction { fun: _, args } => { let mut operands: Vec = Vec::new(); for arg in args { - operands.push(arg.clone().into()); + let py_arg: PyExpr = PyExpr::from(arg.clone(), self.input_plan.clone()); + operands.push(py_arg); } Ok(operands) } Expr::Cast { expr, data_type: _ } => { let mut operands: Vec = Vec::new(); let ex: Expr = *expr.clone(); - operands.push(ex.into()); + let py_ex: PyExpr = PyExpr::from(ex, self.input_plan.clone()); + operands.push(py_ex); Ok(operands) } _ => Err(PyErr::new::( diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 0380ee9bf..b37cb3805 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -11,6 +11,23 @@ pub struct PyProjection { pub(crate) projection: Projection, } +impl PyProjection { + /// Projection: Gets the names of the fields that should be projected + fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec { + let mut projs: Vec = Vec::new(); + match &local_expr.expr { + Expr::Alias(expr, _name) => { + let ex: Expr = *expr.clone(); + let mut py_expr: PyExpr = PyExpr::from(ex, Some(self.projection.input.clone())); + py_expr.input_plan = local_expr.input_plan.clone(); + projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice()); + } + _ => projs.push(local_expr.clone()), + } + projs + } +} + #[pymethods] impl PyProjection { #[pyo3(name = "getColumnName")] @@ -41,7 +58,9 @@ impl PyProjection { } Expr::Cast { expr, data_type: _ } => { let ex_type: Expr = *expr.clone(); - val = self.column_name(ex_type.into()).unwrap(); + let py_type: PyExpr = + PyExpr::from(ex_type, Some(self.projection.input.clone())); + val = self.column_name(py_type).unwrap(); println!("Setting col name to: {:?}", val); } _ => val = name.clone().to_ascii_uppercase(), @@ -49,7 +68,8 @@ impl PyProjection { Expr::Column(col) => val = col.name.clone(), Expr::Cast { expr, data_type: _ } => { let ex_type: Expr = *expr; - val = self.column_name(ex_type.into()).unwrap() + let py_type: PyExpr = PyExpr::from(ex_type, Some(self.projection.input.clone())); + val = self.column_name(py_type).unwrap() } _ => { panic!( @@ -61,25 +81,19 @@ impl PyProjection { Ok(val) } - /// Projection: Gets the names of the fields that should be projected - #[pyo3(name = "getProjectedExpressions")] - fn projected_expressions(&mut self) -> PyResult> { - let mut projs: Vec = Vec::new(); - for expr in &self.projection.expr { - projs.push(PyExpr::from( - expr.clone(), - Some(self.projection.input.clone()), - )); - } - Ok(projs) - } - #[pyo3(name = "getNamedProjects")] fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); - for expr in &self.projected_expressions().unwrap() { - let name: String = self.column_name(expr.clone()).unwrap(); - named.push((name, expr.clone())); + println!("Projection Input: {:?}", &self.projection.input); + for expression in self.projection.expr.clone() { + let mut py_expr: PyExpr = PyExpr::from(expression, 
Some(self.projection.input.clone())); + py_expr.input_plan = Some(self.projection.input.clone()); + println!("Expression Input: {:?}", &py_expr.input_plan); + for expr in self.projected_expressions(&py_expr) { + let name: String = self.column_name(expr.clone()).unwrap(); + println!("Named Project: {:?} - Expr: {:?}", &name, &expr); + named.push((name, expr.clone())); + } } Ok(named) } diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index ed6b3f0ce..cbb9849b0 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -41,6 +41,8 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # Collect all (new) columns this Projection will limit to for key, expr in named_projects: + print(f"Key: {key} - Expr: {expr.toString()}") + key = str(key) column_names.append(key) random_name = new_temporary_column(df) @@ -48,8 +50,6 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai rel, expr, dc, context=context ) - breakpoint() - new_mappings[key] = random_name # shortcut: if we have a column already, there is no need to re-assign it again @@ -66,7 +66,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai print(f"Other for Expr: {expr}") random_name = new_temporary_column(df) new_columns[random_name] = RexConverter.convert( - expr, dc, context=context + rel, expr, dc, context=context ) logger.debug(f"Adding a new column {key} out of {expr}") new_mappings[key] = random_name diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index f83c9bd40..33e571d25 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -14,13 +14,19 @@ logger = logging.getLogger(__name__) +# _REX_TYPE_TO_PLUGIN = { +# "Alias": "InputRef", +# "Column": "InputRef", +# "BinaryExpr": "RexCall", +# "Literal": "RexLiteral", +# "ScalarFunction": "RexCall", +# "Cast": "RexCall", +# } + _REX_TYPE_TO_PLUGIN = { - "Alias": "InputRef", - "Column": "InputRef", - "BinaryExpr": "RexCall", - "Literal": "RexLiteral", - "ScalarFunction": "RexCall", - "Cast": "RexCall", + "RexType.Reference": "InputRef", + "RexType.Call": "RexCall", + "RexType.Literal": "RexLiteral", } @@ -60,7 +66,7 @@ def convert( using the stored plugins and the dictionary of registered dask tables. """ - expr_type = _REX_TYPE_TO_PLUGIN[rex.getExprType()] + expr_type = _REX_TYPE_TO_PLUGIN[str(rex.getRexType())] try: plugin_instance = cls.get_plugin(expr_type) @@ -73,6 +79,8 @@ def convert( f"Processing REX {rex} using {plugin_instance.__class__.__name__}..." 
) + print(f"expr_type: {expr_type} - Expr: {rex.toString()}") + df = plugin_instance.convert(rel, rex, dc, context=context) logger.debug(f"Processed REX {rex} into {LoggableDataFrame(df)}") return df diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 4e6724b06..6486b2bd2 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -224,12 +224,8 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: if not is_frame(operand): # pragma: no cover return operand - breakpoint() - output_type = str(rex.getType()) - python_type = sql_to_python_type( - output_type=sql_to_python_type(output_type.upper()) - ) + python_type = sql_to_python_type(SqlTypeName.fromString(output_type.upper())) return_column = cast_column_to_type(operand, python_type) @@ -878,6 +874,9 @@ def convert( context: "dask_sql.Context", ) -> SeriesOrScalar: logger.debug(f"Expression Operands: {expr.getOperands()}") + print( + f"Expr: {expr.toString()} - # Operands: {len(expr.getOperands())} - Operands[0]: {expr.getOperands()[0].toString()}" + ) # Prepare the operands by turning the RexNodes into python expressions operands = [ RexConverter.convert(rel, o, dc, context=context) @@ -893,7 +892,6 @@ def convert( logger.debug(f"Operator Name: {operator_name}") try: - breakpoint() operation = self.OPERATION_MAPPING[operator_name] except KeyError: try: diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 694451066..f24086284 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -1,16 +1,26 @@ import numpy as np import pandas as pd +<<<<<<< HEAD +======= +import pytest +>>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls -# from dask_sql.utils import ParsingException +from dask_sql.utils import ParsingException from tests.utils import assert_eq +<<<<<<< HEAD # import pytest # def test_select(c, df): # result_df = c.sql("SELECT * FROM df") +======= +>>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls -# assert_eq(result_df, df) +def test_select(c, df): + result_df = c.sql("SELECT * FROM df") + + assert_eq(result_df, df) # @pytest.mark.skip(reason="WIP DataFusion") @@ -24,30 +34,30 @@ # assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) -# def test_select_column(c, df): -# result_df = c.sql("SELECT a FROM df") +def test_select_column(c, df): + result_df = c.sql("SELECT a FROM df") -# assert_eq(result_df, df[["a"]]) + assert_eq(result_df, df[["a"]]) -# def test_select_different_types(c): -# expected_df = pd.DataFrame( -# { -# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), -# "string": ["this is a test", "another test", "äölüć", ""], -# "integer": [1, 2, -4, 5], -# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], -# } -# ) -# c.create_table("df", expected_df) -# result_df = c.sql( -# """ -# SELECT * -# FROM df -# """ -# ) +def test_select_different_types(c): + expected_df = pd.DataFrame( + { + "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), + "string": ["this is a test", "another test", "äölüć", ""], + "integer": [1, 2, -4, 5], + "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], + } + ) + c.create_table("df", expected_df) + result_df = c.sql( + """ + SELECT * + FROM df + """ + ) -# assert_eq(result_df, expected_df) + assert_eq(result_df, expected_df) # @pytest.mark.skip(reason="WIP DataFusion") @@ -104,25 +114,25 @@ # assert_eq(result_df, 
expected_df) -# def test_wrong_input(c): -# with pytest.raises(ParsingException): -# c.sql("""SELECT x FROM df""") +def test_wrong_input(c): + with pytest.raises(ParsingException): + c.sql("""SELECT x FROM df""") -# with pytest.raises(ParsingException): -# c.sql("""SELECT x FROM df""") + with pytest.raises(ParsingException): + c.sql("""SELECT x FROM df""") -# def test_timezones(c, datetime_table): -# result_df = c.sql( -# """ -# SELECT * FROM datetime_table -# """ -# ) +def test_timezones(c, datetime_table): + result_df = c.sql( + """ + SELECT * FROM datetime_table + """ + ) -# print(f"Expected DF: \n{datetime_table.head(10)}\n") -# print(f"\nResult DF: \n{result_df.head(10)}") + print(f"Expected DF: \n{datetime_table.head(10)}\n") + print(f"\nResult DF: \n{result_df.head(10)}") -# assert_eq(result_df, datetime_table) + assert_eq(result_df, datetime_table) # @pytest.mark.skip(reason="WIP DataFusion") @@ -148,25 +158,24 @@ # assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) -# # @pytest.mark.skip(reason="WIP DataFusion") -# @pytest.mark.parametrize( -# "input_table", -# [ -# "datetime_table", -# # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), -# ], -# ) -# def test_date_casting(c, input_table, request): -# datetime_table = request.getfixturevalue(input_table) -# result_df = c.sql( -# f""" -# SELECT -# CAST(timezone AS DATE) AS timezone, -# CAST(no_timezone AS DATE) AS no_timezone, -# CAST(utc_timezone AS DATE) AS utc_timezone -# FROM {input_table} -# """ -# ) +@pytest.mark.parametrize( + "input_table", + [ + "datetime_table", + # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), + ], +) +def test_date_casting(c, input_table, request): + datetime_table = request.getfixturevalue(input_table) + result_df = c.sql( + f""" + SELECT + CAST(timezone AS DATE) AS timezone, + CAST(no_timezone AS DATE) AS no_timezone, + CAST(utc_timezone AS DATE) AS utc_timezone + FROM {input_table} + """ + ) # expected_df = datetime_table # expected_df["timezone"] = ( @@ -185,28 +194,27 @@ # assert_eq(result_df, expected_df) -# @pytest.mark.skip(reason="WIP DataFusion") -# @pytest.mark.parametrize( -# "input_table", -# [ -# "datetime_table", -# pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), -# ], -# ) -# def test_timestamp_casting(c, input_table, request): -# datetime_table = request.getfixturevalue(input_table) -# result_df = c.sql( -# f""" -# SELECT -# CAST(timezone AS TIMESTAMP) AS timezone, -# CAST(no_timezone AS TIMESTAMP) AS no_timezone, -# CAST(utc_timezone AS TIMESTAMP) AS utc_timezone -# FROM {input_table} -# """ -# ) +@pytest.mark.parametrize( + "input_table", + [ + "datetime_table", + pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), + ], +) +def test_timestamp_casting(c, input_table, request): + datetime_table = request.getfixturevalue(input_table) + result_df = c.sql( + f""" + SELECT + CAST(timezone AS TIMESTAMP) AS timezone, + CAST(no_timezone AS TIMESTAMP) AS no_timezone, + CAST(utc_timezone AS TIMESTAMP) AS utc_timezone + FROM {input_table} + """ + ) -# expected_df = datetime_table.astype(" Date: Thu, 28 Apr 2022 08:01:08 -0400 Subject: [PATCH 34/61] skip test that uses create statement for gpuci --- tests/unit/test_context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 12e85b69c..269713915 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -116,6 +116,7 @@ def test_sql(gpu): assert_eq(result, data_frame) 
+@pytest.mark.skip(reason="WIP DataFusion - missing create statement logic") @pytest.mark.parametrize( "gpu", [ From 643e85dbd943f0975936ce822caebe0429530efd Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 14:03:26 -0400 Subject: [PATCH 35/61] Basic DataFusion Select Functionality (#489) * Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral * Updates for test_filter * more of test_filter.py working with the exception of some date pytests * Add workflow to keep datafusion dev branch up to date (#440) * Include setuptools-rust in conda build recipie, in host and run * Remove PyArrow dependency * rebase with datafusion-sql-planner * refactor changes that were inadvertent during rebase * timestamp with loglca time zone * Include RelDataType work * Include RelDataType work * Introduced SqlTypeName Enum in Rust and mappings for Python * impl PyExpr.getIndex() * add getRowType() for logical.rs * Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes * use str values instead of Rust Enums, Python is unable to Hash the Rust Enums if used in a dict * linter changes, why did that work on my local pre-commit?? * linter changes, why did that work on my local pre-commit?? * Convert final strs to SqlTypeName Enum * removed a few print statements * Temporarily disable conda run_test.py script since it uses features not yet implemented * expose fromString method for SqlTypeName to use Enums instead of strings for type checking * expanded SqlTypeName from_string() support * accept INT as INTEGER * Remove print statements * Default to UTC if tz is None * Delegate timezone handling to the arrow library * Updates from review Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> --- dask_planner/src/expression.rs | 160 ++++++++++++++++++++++- dask_sql/physical/rel/logical/project.py | 11 +- tests/integration/test_select.py | 11 -- 3 files changed, 156 insertions(+), 26 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index b3449e623..4db1afcf2 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -249,14 +249,57 @@ impl PyExpr { #[pyo3(name = "getRexType")] pub fn rex_type(&self) -> RexType { match &self.expr { -<<<<<<< HEAD - Expr::Alias(..) => self.rex_type(), -======= Expr::Alias(expr, name) => { - println!("expr: {:?}", *expr); - RexType::Reference + println!("Alias encountered with name: {:?}", name); + // let reference: Expr = *expr.as_ref(); + // let plan: logical::PyLogicalPlan = reference.input().clone().into(); + + // Only certain LogicalPlan variants are valid in this nested Alias scenario so we + // extract the valid ones and error on the invalid ones + match expr.as_ref() { + Expr::Column(col) => { + // First we must iterate the current node before getting its input + match plan { + LogicalPlan::Projection(proj) => { + match proj.input.as_ref() { + LogicalPlan::Aggregate(agg) => { + let mut exprs = agg.group_expr.clone(); + exprs.extend_from_slice(&agg.aggr_expr); + let col_index: usize = + proj.input.schema().index_of_column(col).unwrap(); + // match &exprs[plan.get_index(col)] { + match &exprs[col_index] { + Expr::AggregateFunction { args, .. 
} => { + match &args[0] { + Expr::Column(col) => { + println!( + "AGGREGATE COLUMN IS {}", + col.name + ); + col.name.clone() + } + _ => name.clone(), + } + } + _ => name.clone(), + } + } + _ => { + println!("Encountered a non-Aggregate type"); + + name.clone() + } + } + } + _ => name.clone(), + } + } + _ => { + println!("Encountered a non Expr::Column instance"); + name.clone() + } + } } ->>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls Expr::Column(..) => RexType::Reference, Expr::ScalarVariable(..) => RexType::Literal, Expr::Literal(..) => RexType::Literal, @@ -280,6 +323,111 @@ impl PyExpr { _ => RexType::Other, } } +} + +#[pymethods] +impl PyExpr { + #[staticmethod] + pub fn literal(value: PyScalarValue) -> PyExpr { + lit(value.scalar_value).into() + } + + /// If this Expression instances references an existing + /// Column in the SQL parse tree or not + #[pyo3(name = "isInputReference")] + pub fn is_input_reference(&self) -> PyResult { + match &self.expr { + Expr::Column(_col) => Ok(true), + _ => Ok(false), + } + } + + /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema + #[pyo3(name = "getIndex")] + pub fn index(&self) -> PyResult { + let input: &Option> = &self.input_plan; + match input { + Some(plan) => { + let name: Result = self.expr.name(plan.schema()); + match name { + Ok(fq_name) => Ok(plan + .schema() + .index_of_column(&Column::from_qualified_name(&fq_name)) + .unwrap()), + Err(e) => panic!("{:?}", e), + } + } + None => { + panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") + } + } + } + + /// Examine the current/"self" PyExpr and return its "type" + /// In this context a "type" is what Dask-SQL Python + /// RexConverter plugin instance should be invoked to handle + /// the Rex conversion + #[pyo3(name = "getExprType")] + pub fn get_expr_type(&self) -> String { + String::from(match &self.expr { + Expr::Alias(..) => "Alias", + Expr::Column(..) => "Column", + Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), + Expr::Literal(..) => "Literal", + Expr::BinaryExpr { .. } => "BinaryExpr", + Expr::Not(..) => panic!("Not!!!"), + Expr::IsNotNull(..) => panic!("IsNotNull!!!"), + Expr::Negative(..) => panic!("Negative!!!"), + Expr::GetIndexedField { .. } => panic!("GetIndexedField!!!"), + Expr::IsNull(..) => panic!("IsNull!!!"), + Expr::Between { .. } => panic!("Between!!!"), + Expr::Case { .. } => panic!("Case!!!"), + Expr::Cast { .. } => "Cast", + Expr::TryCast { .. } => panic!("TryCast!!!"), + Expr::Sort { .. } => panic!("Sort!!!"), + Expr::ScalarFunction { .. } => "ScalarFunction", + Expr::AggregateFunction { .. } => "AggregateFunction", + Expr::WindowFunction { .. } => panic!("WindowFunction!!!"), + Expr::AggregateUDF { .. } => panic!("AggregateUDF!!!"), + Expr::InList { .. } => panic!("InList!!!"), + Expr::Wildcard => panic!("Wildcard!!!"), + _ => "OTHER", + }) + } + + /// Determines the type of this Expr based on its variant + #[pyo3(name = "getRexType")] + pub fn rex_type(&self) -> RexType { + match &self.expr { + Expr::Alias(..) => RexType::Reference, + Expr::Column(..) => RexType::Reference, + Expr::ScalarVariable(..) => RexType::Literal, + Expr::Literal(..) => RexType::Literal, + Expr::BinaryExpr { .. } => RexType::Call, + Expr::Not(..) => RexType::Call, + Expr::IsNotNull(..) => RexType::Call, + Expr::Negative(..) => RexType::Call, + Expr::GetIndexedField { .. } => RexType::Reference, + Expr::IsNull(..) => RexType::Call, + Expr::Between { .. 
} => RexType::Call, + Expr::Case { .. } => RexType::Call, + Expr::Cast { .. } => RexType::Call, + Expr::TryCast { .. } => RexType::Call, + Expr::Sort { .. } => RexType::Call, + Expr::ScalarFunction { .. } => RexType::Call, + Expr::AggregateFunction { .. } => RexType::Call, + Expr::WindowFunction { .. } => RexType::Call, + Expr::AggregateUDF { .. } => RexType::Call, + Expr::InList { .. } => RexType::Call, + Expr::Wildcard => RexType::Call, + _ => RexType::Other, + } + } + + /// Python friendly shim code to get the name of a column referenced by an expression + pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { + self._column_name(plan.current_node()) + } /// Python friendly shim code to get the name of a column referenced by an expression pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index cbb9849b0..0441fe486 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -41,10 +41,9 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # Collect all (new) columns this Projection will limit to for key, expr in named_projects: - print(f"Key: {key} - Expr: {expr.toString()}") - key = str(key) column_names.append(key) + random_name = new_temporary_column(df) new_columns[random_name] = RexConverter.convert( rel, expr, dc, context=context @@ -55,15 +54,13 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # shortcut: if we have a column already, there is no need to re-assign it again # this is only the case if the expr is a RexInputRef if expr.getRexType() == RexType.Reference: - print(f"Reference for Expr: {expr}") - index = expr.getIndex(rel) + index = expr.getIndex() backend_column_name = cc.get_backend_by_frontend_index(index) logger.debug( f"Not re-adding the same column {key} (but just referencing it)" ) new_mappings[key] = backend_column_name else: - print(f"Other for Expr: {expr}") random_name = new_temporary_column(df) new_columns[random_name] = RexConverter.convert( rel, expr, dc, context=context @@ -71,8 +68,6 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai logger.debug(f"Adding a new column {key} out of {expr}") new_mappings[key] = random_name - print(f"Projecting columns: {column_names}") - # Actually add the new columns if new_columns: df = df.assign(**new_columns) @@ -88,6 +83,4 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai dc = DataContainer(df, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) - print(f"After Project: {dc.df.head(10)}") - return dc diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index f24086284..cd8c2104f 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -1,21 +1,10 @@ import numpy as np import pandas as pd -<<<<<<< HEAD -======= import pytest ->>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls from dask_sql.utils import ParsingException from tests.utils import assert_eq -<<<<<<< HEAD -# import pytest - - -# def test_select(c, df): -# result_df = c.sql("SELECT * FROM df") -======= ->>>>>>> Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls def test_select(c, df): result_df = c.sql("SELECT * FROM df") From b36ef16a9f4a9ce4c10cef348c87ae34a390abdc Mon Sep 17 
00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 14:58:02 -0400 Subject: [PATCH 36/61] updates for expression --- dask_planner/src/expression.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 4db1afcf2..40568ab75 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -126,7 +126,16 @@ impl PyExpr { Expr::GetIndexedField { .. } => unimplemented!("GetIndexedField!!!"), Expr::IsNull(..) => unimplemented!("IsNull!!!"), Expr::Between { .. } => unimplemented!("Between!!!"), - Expr::Case { .. } => unimplemented!("Case!!!"), + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + println!("expr: {:?}", &expr); + println!("when_then_expr: {:?}", &when_then_expr); + println!("else_expr: {:?}", &else_expr); + unimplemented!("CASE!!!") + } Expr::Cast { .. } => unimplemented!("Cast!!!"), Expr::TryCast { .. } => unimplemented!("TryCast!!!"), Expr::Sort { .. } => unimplemented!("Sort!!!"), From 5c94fbc5c8438c763e85cabd244854e3b46e219f Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:15:26 -0400 Subject: [PATCH 37/61] uncommented pytests --- tests/integration/test_select.py | 144 +++++++++++++++---------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index cd8c2104f..e7fe1c3e4 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -12,15 +12,15 @@ def test_select(c, df): assert_eq(result_df, df) -# @pytest.mark.skip(reason="WIP DataFusion") -# def test_select_alias(c, df): -# result_df = c.sql("SELECT a as b, b as a FROM df") +@pytest.mark.skip(reason="WIP DataFusion") +def test_select_alias(c, df): + result_df = c.sql("SELECT a as b, b as a FROM df") -# expected_df = pd.DataFrame(index=df.index) -# expected_df["b"] = df.a -# expected_df["a"] = df.b + expected_df = pd.DataFrame(index=df.index) + expected_df["b"] = df.a + expected_df["a"] = df.b -# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) + assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) def test_select_column(c, df): @@ -49,58 +49,58 @@ def test_select_different_types(c): assert_eq(result_df, expected_df) -# @pytest.mark.skip(reason="WIP DataFusion") -# def test_select_expr(c, df): -# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") -# result_df = result_df +@pytest.mark.skip(reason="WIP DataFusion") +def test_select_expr(c, df): + result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") + result_df = result_df -# expected_df = pd.DataFrame( -# { -# "a": df["a"] + 1, -# "bla": df["b"], -# '"df"."a" - 1': df["a"] - 1, -# } -# ) -# assert_eq(result_df, expected_df) + expected_df = pd.DataFrame( + { + "a": df["a"] + 1, + "bla": df["b"], + '"df"."a" - 1': df["a"] - 1, + } + ) + assert_eq(result_df, expected_df) -# @pytest.mark.skip( -# reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" -# ) -# def test_select_of_select(c, df): -# result_df = c.sql( -# """ -# SELECT 2*c AS e, d - 1 AS f -# FROM -# ( -# SELECT a - 1 AS c, 2*b AS d -# FROM df -# ) AS "inner" -# """ -# ) +@pytest.mark.skip( + reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" +) +def test_select_of_select(c, df): + result_df = c.sql( + """ + SELECT 2*c AS e, d - 1 AS f + FROM + ( + SELECT a - 1 AS c, 2*b AS d + FROM df + ) AS "inner" + """ + ) -# expected_df = pd.DataFrame({"e": 2 * 
(df["a"] - 1), "f": 2 * df["b"] - 1}) -# assert_eq(result_df, expected_df) + expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) + assert_eq(result_df, expected_df) -# @pytest.mark.skip(reason="WIP DataFusion") -# def test_select_of_select_with_casing(c, df): -# result_df = c.sql( -# """ -# SELECT AAA, aaa, aAa -# FROM -# ( -# SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA -# FROM df -# ) AS "inner" -# """ -# ) +@pytest.mark.skip(reason="WIP DataFusion") +def test_select_of_select_with_casing(c, df): + result_df = c.sql( + """ + SELECT AAA, aaa, aAa + FROM + ( + SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA + FROM df + ) AS "inner" + """ + ) -# expected_df = pd.DataFrame( -# {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} -# ) + expected_df = pd.DataFrame( + {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} + ) -# assert_eq(result_df, expected_df) + assert_eq(result_df, expected_df) def test_wrong_input(c): @@ -124,27 +124,27 @@ def test_timezones(c, datetime_table): assert_eq(result_df, datetime_table) -# @pytest.mark.skip(reason="WIP DataFusion") -# @pytest.mark.parametrize( -# "input_table", -# [ -# "long_table", -# pytest.param("gpu_long_table", marks=pytest.mark.gpu), -# ], -# ) -# @pytest.mark.parametrize( -# "limit,offset", -# [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], -# ) -# def test_limit(c, input_table, limit, offset, request): -# long_table = request.getfixturevalue(input_table) - -# if not limit: -# query = f"SELECT * FROM long_table OFFSET {offset}" -# else: -# query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" - -# assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) +@pytest.mark.skip(reason="WIP DataFusion") +@pytest.mark.parametrize( + "input_table", + [ + "long_table", + pytest.param("gpu_long_table", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + "limit,offset", + [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], +) +def test_limit(c, input_table, limit, offset, request): + long_table = request.getfixturevalue(input_table) + + if not limit: + query = f"SELECT * FROM long_table OFFSET {offset}" + else: + query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + + assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) @pytest.mark.parametrize( From bb461c8638124121f0bc03ab03002ec008d7802b Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:15:57 -0400 Subject: [PATCH 38/61] uncommented pytests --- tests/integration/test_select.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index e7fe1c3e4..a786ef0c1 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -118,9 +118,6 @@ def test_timezones(c, datetime_table): """ ) - print(f"Expected DF: \n{datetime_table.head(10)}\n") - print(f"\nResult DF: \n{result_df.head(10)}") - assert_eq(result_df, datetime_table) From f65b1abf758b34d272c1876d87639e249ad08ec8 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:19:33 -0400 Subject: [PATCH 39/61] code cleanup for review --- continuous_integration/recipe/meta.yaml | 1 - dask_planner/src/expression.rs | 2 -- dask_planner/src/sql.rs | 1 - dask_planner/src/sql/logical/projection.rs | 3 --- dask_sql/mappings.py | 2 -- dask_sql/physical/rex/convert.py | 2 -- dask_sql/physical/rex/core/call.py | 3 --- 7 files changed, 14 deletions(-) 
diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml index 6d6ef2ced..331a8ca7e 100644 --- a/continuous_integration/recipe/meta.yaml +++ b/continuous_integration/recipe/meta.yaml @@ -30,7 +30,6 @@ requirements: - setuptools-rust>=1.1.2 run: - python - - setuptools-rust>=1.1.2 - dask >=2022.3.0 - pandas >=1.0.0 - fastapi >=0.61.1 diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 40568ab75..a156b05b3 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -202,8 +202,6 @@ impl PyExpr { /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema #[pyo3(name = "getIndex")] pub fn index(&self) -> PyResult { - println!("&self: {:?}", &self); - println!("&self.input_plan: {:?}", self.input_plan); let input: &Option> = &self.input_plan; match input { Some(plan) => { diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 2cbc6feef..bfade7bb9 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -158,7 +158,6 @@ impl DaskSQLContext { statement: statement::PyStatement, ) -> PyResult { let planner = SqlToRel::new(self); - println!("Statement: {:?}", statement.statement); match planner.statement_to_plan(statement.statement) { Ok(k) => { println!("\nLogicalPlan: {:?}\n\n", k); diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index b37cb3805..b7e4eb019 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -84,14 +84,11 @@ impl PyProjection { #[pyo3(name = "getNamedProjects")] fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); - println!("Projection Input: {:?}", &self.projection.input); for expression in self.projection.expr.clone() { let mut py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone())); py_expr.input_plan = Some(self.projection.input.clone()); - println!("Expression Input: {:?}", &py_expr.input_plan); for expr in self.projected_expressions(&py_expr) { let name: String = self.column_name(expr.clone()).unwrap(); - println!("Named Project: {:?} - Expr: {:?}", &name, &expr); named.push((name, expr.clone())); } } diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 1d2c56ec3..36ca8d6df 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -291,7 +291,5 @@ def cast_column_to_type(col: dd.Series, expected_type: str): # will convert both NA and np.NaN to NA. col = da.trunc(col.fillna(value=np.NaN)) - print(f"Need to cast from {current_type} to {expected_type}") col = col.astype(expected_type) - print(f"col type: {col.dtype}") return col diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index 33e571d25..bbbeda1db 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -79,8 +79,6 @@ def convert( f"Processing REX {rex} using {plugin_instance.__class__.__name__}..." 
) - print(f"expr_type: {expr_type} - Expr: {rex.toString()}") - df = plugin_instance.convert(rel, rex, dc, context=context) logger.debug(f"Processed REX {rex} into {LoggableDataFrame(df)}") return df diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 6486b2bd2..1a74ce9d5 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -874,9 +874,6 @@ def convert( context: "dask_sql.Context", ) -> SeriesOrScalar: logger.debug(f"Expression Operands: {expr.getOperands()}") - print( - f"Expr: {expr.toString()} - # Operands: {len(expr.getOperands())} - Operands[0]: {expr.getOperands()[0].toString()}" - ) # Prepare the operands by turning the RexNodes into python expressions operands = [ RexConverter.convert(rel, o, dc, context=context) From dc7553f8dfc0ba055da76224f1df2996e73e37ba Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:21:16 -0400 Subject: [PATCH 40/61] code cleanup for review --- dask_sql/mappings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 36ca8d6df..47d8624da 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -291,5 +291,5 @@ def cast_column_to_type(col: dd.Series, expected_type: str): # will convert both NA and np.NaN to NA. col = da.trunc(col.fillna(value=np.NaN)) - col = col.astype(expected_type) - return col + logger.debug(f"Need to cast from {current_type} to {expected_type}") + return col.astype(expected_type) From f1dc0b2a90d3dd9088b3a102b540136b85f23aea Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:29:41 -0400 Subject: [PATCH 41/61] Enabled more pytest that work now --- tests/integration/test_filter.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index b7f8a29f0..d22a55e08 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -14,7 +14,6 @@ def test_filter(c, df): assert_eq(return_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") def test_filter_scalar(c, df): return_df = c.sql("SELECT * FROM df WHERE True") @@ -37,7 +36,6 @@ def test_filter_scalar(c, df): assert_eq(return_df, expected_df, check_index_type=False) -@pytest.mark.skip(reason="WIP DataFusion") def test_filter_complicated(c, df): return_df = c.sql("SELECT * FROM df WHERE a < 3 AND (b > 1 AND b < 3)") @@ -48,7 +46,6 @@ def test_filter_complicated(c, df): ) -@pytest.mark.skip(reason="WIP DataFusion") def test_filter_with_nan(c): return_df = c.sql("SELECT * FROM user_table_nan WHERE c = 3") @@ -62,7 +59,6 @@ def test_filter_with_nan(c): ) -@pytest.mark.skip(reason="WIP DataFusion") def test_string_filter(c, string_table): return_df = c.sql("SELECT * FROM string_table WHERE a = 'a normal string'") @@ -72,7 +68,6 @@ def test_string_filter(c, string_table): ) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ @@ -96,7 +91,6 @@ def test_filter_cast_date(c, input_table, request): assert_eq(return_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ @@ -206,7 +200,6 @@ def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters): assert_eq(return_df, expected_df, check_divisions=False) -@pytest.mark.skip(reason="WIP DataFusion") def test_filtered_csv(tmpdir, c): # Predicate pushdown is NOT supported for CSV data. 
# This test just checks that the "attempted" From 940e8670c2a930eb4aee2325045e21b1ecea65eb Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 15:42:10 -0400 Subject: [PATCH 42/61] Enabled more pytest that work now --- tests/integration/test_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index d22a55e08..a128f75d5 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -72,7 +72,7 @@ def test_string_filter(c, string_table): "input_table", [ "datetime_table", - pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), + # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), ], ) def test_filter_cast_date(c, input_table, request): From 6769ca0dab6fd6e25ab50a47743f9004ceae7e7d Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 20:42:38 -0400 Subject: [PATCH 43/61] Output Expression as String when BinaryExpr does not contain a named alias --- dask_planner/src/expression.rs | 21 ++-- dask_planner/src/sql/logical/projection.rs | 11 +- dask_sql/physical/rel/logical/project.py | 4 + tests/integration/test_select.py | 121 ++++++++++++--------- 4 files changed, 94 insertions(+), 63 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index a156b05b3..f256fd5a6 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -112,14 +112,21 @@ impl PyExpr { Expr::ScalarVariable(..) => unimplemented!("ScalarVariable!!!"), Expr::Literal(..) => unimplemented!("Literal!!!"), Expr::BinaryExpr { - left: _, - op: _, - right: _, + left, + op, + right, } => { - // /// TODO: Examine this more deeply about whether name comes from the left or right - // self.column_name(left) - unimplemented!("BinaryExpr HERE!!!") - } + println!("left: {:?}", &left); + println!("op: {:?}", &op); + println!("right: {:?}", &right); + // If the BinaryExpr does not have an Alias + // Ex: `df.a - Int64(1)` then use the String + // representation of the Expr to match what is + // in the DFSchemaRef instance + let sample_name: String = format!("{}", &self.expr); + println!("BinaryExpr Name: {:?}", sample_name); + sample_name + }, Expr::Not(..) => unimplemented!("Not!!!"), Expr::IsNotNull(..) => unimplemented!("IsNotNull!!!"), Expr::Negative(..) 
=> unimplemented!("Negative!!!"), diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index b7e4eb019..6933145f2 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -14,12 +14,14 @@ pub struct PyProjection { impl PyProjection { /// Projection: Gets the names of the fields that should be projected fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec { + println!("Exprs: {:?}", &self.projection.expr); + println!("Input: {:?}", &self.projection.input); + println!("Schema: {:?}", &self.projection.schema); + println!("Alias: {:?}", &self.projection.alias); let mut projs: Vec = Vec::new(); match &local_expr.expr { Expr::Alias(expr, _name) => { - let ex: Expr = *expr.clone(); - let mut py_expr: PyExpr = PyExpr::from(ex, Some(self.projection.input.clone())); - py_expr.input_plan = local_expr.input_plan.clone(); + let py_expr: PyExpr = PyExpr::from(*expr.clone(), Some(self.projection.input.clone())); projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice()); } _ => projs.push(local_expr.clone()), @@ -85,8 +87,7 @@ impl PyProjection { fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); for expression in self.projection.expr.clone() { - let mut py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone())); - py_expr.input_plan = Some(self.projection.input.clone()); + let py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone())); for expr in self.projected_expressions(&py_expr) { let name: String = self.column_name(expr.clone()).unwrap(); named.push((name, expr.clone())); diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 0441fe486..ba12025ab 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -38,6 +38,10 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai new_columns = {} new_mappings = {} + # Debugging only + for key, expr in named_projects: + print(f"Key: {key} - Expr: {expr.toString()}", str(key), expr) + # Collect all (new) columns this Projection will limit to for key, expr in named_projects: diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index a786ef0c1..af408ac8b 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -12,56 +12,54 @@ def test_select(c, df): assert_eq(result_df, df) -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_alias(c, df): - result_df = c.sql("SELECT a as b, b as a FROM df") +# def test_select_alias(c, df): +# result_df = c.sql("SELECT a as b, b as a FROM df") - expected_df = pd.DataFrame(index=df.index) - expected_df["b"] = df.a - expected_df["a"] = df.b +# expected_df = pd.DataFrame(index=df.index) +# expected_df["b"] = df.a +# expected_df["a"] = df.b - assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) +# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) -def test_select_column(c, df): - result_df = c.sql("SELECT a FROM df") +# def test_select_column(c, df): +# result_df = c.sql("SELECT a FROM df") - assert_eq(result_df, df[["a"]]) +# assert_eq(result_df, df[["a"]]) -def test_select_different_types(c): - expected_df = pd.DataFrame( - { - "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), - "string": ["this is a test", "another test", "äölüć", ""], - "integer": [1, 2, -4, 5], - "float": [-1.1, np.NaN, pd.NA, 
np.sqrt(2)], - } - ) - c.create_table("df", expected_df) - result_df = c.sql( - """ - SELECT * - FROM df - """ - ) +# def test_select_different_types(c): +# expected_df = pd.DataFrame( +# { +# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), +# "string": ["this is a test", "another test", "äölüć", ""], +# "integer": [1, 2, -4, 5], +# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], +# } +# ) +# c.create_table("df", expected_df) +# result_df = c.sql( +# """ +# SELECT * +# FROM df +# """ +# ) - assert_eq(result_df, expected_df) +# assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") -def test_select_expr(c, df): - result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") - result_df = result_df +# def test_select_expr(c, df): +# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") +# result_df = result_df - expected_df = pd.DataFrame( - { - "a": df["a"] + 1, - "bla": df["b"], - '"df"."a" - 1': df["a"] - 1, - } - ) - assert_eq(result_df, expected_df) +# expected_df = pd.DataFrame( +# { +# "a": df["a"] + 1, +# "bla": df["b"], +# 'df.a - Int64(1)': df["a"] - 1, +# } +# ) +# assert_eq(result_df, expected_df) @pytest.mark.skip( @@ -103,22 +101,22 @@ def test_select_of_select_with_casing(c, df): assert_eq(result_df, expected_df) -def test_wrong_input(c): - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") +# def test_wrong_input(c): +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") -def test_timezones(c, datetime_table): - result_df = c.sql( - """ - SELECT * FROM datetime_table - """ - ) +# def test_timezones(c, datetime_table): +# result_df = c.sql( +# """ +# SELECT * FROM datetime_table +# """ +# ) - assert_eq(result_df, datetime_table) +# assert_eq(result_df, datetime_table) @pytest.mark.skip(reason="WIP DataFusion") @@ -144,6 +142,7 @@ def test_limit(c, input_table, limit, offset, request): assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) +<<<<<<< HEAD @pytest.mark.parametrize( "input_table", [ @@ -162,6 +161,26 @@ def test_date_casting(c, input_table, request): FROM {input_table} """ ) +======= +# @pytest.mark.parametrize( +# "input_table", +# [ +# "datetime_table", +# pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), +# ], +# ) +# def test_date_casting(c, input_table, request): +# datetime_table = request.getfixturevalue(input_table) +# result_df = c.sql( +# f""" +# SELECT +# CAST(timezone AS DATE) AS timezone, +# CAST(no_timezone AS DATE) AS no_timezone, +# CAST(utc_timezone AS DATE) AS utc_timezone +# FROM {input_table} +# """ +# ) +>>>>>>> Output Expression as String when BinaryExpr does not contain a named alias # expected_df = datetime_table # expected_df["timezone"] = ( From c4ed9bdfe7eb327ff589a55de2062363acf3a4e7 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 20:42:52 -0400 Subject: [PATCH 44/61] Output Expression as String when BinaryExpr does not contain a named alias --- dask_planner/src/expression.rs | 10 +++------- dask_planner/src/sql/logical/projection.rs | 3 ++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f256fd5a6..59532e1a7 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -111,22 +111,18 @@ impl PyExpr { Expr::Column(column) => 
column.name.clone(), Expr::ScalarVariable(..) => unimplemented!("ScalarVariable!!!"), Expr::Literal(..) => unimplemented!("Literal!!!"), - Expr::BinaryExpr { - left, - op, - right, - } => { + Expr::BinaryExpr { left, op, right } => { println!("left: {:?}", &left); println!("op: {:?}", &op); println!("right: {:?}", &right); // If the BinaryExpr does not have an Alias // Ex: `df.a - Int64(1)` then use the String - // representation of the Expr to match what is + // representation of the Expr to match what is // in the DFSchemaRef instance let sample_name: String = format!("{}", &self.expr); println!("BinaryExpr Name: {:?}", sample_name); sample_name - }, + } Expr::Not(..) => unimplemented!("Not!!!"), Expr::IsNotNull(..) => unimplemented!("IsNotNull!!!"), Expr::Negative(..) => unimplemented!("Negative!!!"), diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 6933145f2..796fcc319 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -21,7 +21,8 @@ impl PyProjection { let mut projs: Vec = Vec::new(); match &local_expr.expr { Expr::Alias(expr, _name) => { - let py_expr: PyExpr = PyExpr::from(*expr.clone(), Some(self.projection.input.clone())); + let py_expr: PyExpr = + PyExpr::from(*expr.clone(), Some(self.projection.input.clone())); projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice()); } _ => projs.push(local_expr.clone()), From 05c5788ac17547a5b3265589ea72bb867996aeb5 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 28 Apr 2022 20:45:34 -0400 Subject: [PATCH 45/61] Disable 2 pytest that are causing gpuCI issues. They will be address in a follow up PR --- tests/integration/test_filter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index a128f75d5..e2ba74bd9 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -68,11 +68,12 @@ def test_string_filter(c, string_table): ) +@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ "datetime_table", - # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), + pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), ], ) def test_filter_cast_date(c, input_table, request): @@ -91,6 +92,7 @@ def test_filter_cast_date(c, input_table, request): assert_eq(return_df, expected_df) +@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ From a33aa63042b7522e6df582f1ca8fe9f8bdef57ce Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Fri, 29 Apr 2022 11:28:10 -0400 Subject: [PATCH 46/61] Handle Between operation for case-when --- dask_planner/src/expression.rs | 81 +++++++++---- dask_sql/physical/rel/logical/limit.py | 8 +- dask_sql/physical/rex/core/call.py | 24 +++- tests/integration/test_select.py | 158 +++++++++++-------------- 4 files changed, 153 insertions(+), 118 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 59532e1a7..b21d1a027 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -111,17 +111,12 @@ impl PyExpr { Expr::Column(column) => column.name.clone(), Expr::ScalarVariable(..) => unimplemented!("ScalarVariable!!!"), Expr::Literal(..) => unimplemented!("Literal!!!"), - Expr::BinaryExpr { left, op, right } => { - println!("left: {:?}", &left); - println!("op: {:?}", &op); - println!("right: {:?}", &right); + Expr::BinaryExpr { .. 
} => { // If the BinaryExpr does not have an Alias // Ex: `df.a - Int64(1)` then use the String // representation of the Expr to match what is // in the DFSchemaRef instance - let sample_name: String = format!("{}", &self.expr); - println!("BinaryExpr Name: {:?}", sample_name); - sample_name + format!("{}", &self.expr) } Expr::Not(..) => unimplemented!("Not!!!"), Expr::IsNotNull(..) => unimplemented!("IsNotNull!!!"), @@ -129,17 +124,8 @@ impl PyExpr { Expr::GetIndexedField { .. } => unimplemented!("GetIndexedField!!!"), Expr::IsNull(..) => unimplemented!("IsNull!!!"), Expr::Between { .. } => unimplemented!("Between!!!"), - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - println!("expr: {:?}", &expr); - println!("when_then_expr: {:?}", &when_then_expr); - println!("else_expr: {:?}", &else_expr); - unimplemented!("CASE!!!") - } - Expr::Cast { .. } => unimplemented!("Cast!!!"), + Expr::Case { .. } => format!("{}", &self.expr), + Expr::Cast { .. } => format!("{}", &self.expr), Expr::TryCast { .. } => unimplemented!("TryCast!!!"), Expr::Sort { .. } => unimplemented!("Sort!!!"), Expr::ScalarFunction { .. } => unimplemented!("ScalarFunction!!!"), @@ -473,9 +459,44 @@ impl PyExpr { operands.push(py_ex); Ok(operands) } - _ => Err(PyErr::new::( - "unknown Expr type encountered", - )), + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let mut operands: Vec = Vec::new(); + match expr { + Some(e) => operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())), + None => (), + }; + + for (when, then) in when_then_expr { + operands.push(PyExpr::from(*when.clone(), self.input_plan.clone())); + operands.push(PyExpr::from(*then.clone(), self.input_plan.clone())); + } + + match else_expr { + Some(e) => operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())), + None => (), + } + Ok(operands) + } + Expr::Between { + expr, + negated: _, + low, + high, + } => { + let mut operands: Vec = Vec::new(); + operands.push(PyExpr::from(*expr.clone(), self.input_plan.clone())); + operands.push(PyExpr::from(*low.clone(), self.input_plan.clone())); + operands.push(PyExpr::from(*high.clone(), self.input_plan.clone())); + Ok(operands) + } + _ => Err(PyErr::new::(format!( + "unknown Expr type {:?} encountered", + &self.expr + ))), } } @@ -492,9 +513,21 @@ impl PyExpr { expr: _, data_type: _, } => Ok(String::from("cast")), - _ => Err(PyErr::new::( - "Catch all triggered ....", - )), + Expr::Between { + expr: _, + negated: _, + low: _, + high: _, + } => Ok(String::from("between")), + Expr::Case { + expr: _, + when_then_expr: _, + else_expr: _, + } => Ok(String::from("case")), + _ => Err(PyErr::new::(format!( + "Catch all triggered for get_operator_name: {:?}", + &self.expr + ))), } } diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index 58cd68fe8..76773e37e 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: import dask_sql - from dask_sql.java import org + from dask_planner.rust import LogicalPlan class DaskLimitPlugin(BaseRelPlugin): @@ -18,11 +18,9 @@ class DaskLimitPlugin(BaseRelPlugin): (LIMIT). 
""" - class_name = "com.dask.sql.nodes.DaskLimit" + class_name = "Limit" - def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" - ) -> DataContainer: + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 1a74ce9d5..85178e49a 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -765,6 +765,18 @@ def date_part(self, what, df: SeriesOrScalar): raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.") +class BetweenOperation(Operation): + """ + Function for finding rows between two scalar values + """ + + def __init__(self): + super().__init__(self.between) + + def between(self, series: dd.Series, low, high): + return series.between(low, high, inclusive="both") + + class RexCallPlugin(BaseRexPlugin): """ RexCall is used for expressions, which calculate something. @@ -785,6 +797,7 @@ class RexCallPlugin(BaseRexPlugin): OPERATION_MAPPING = { # "binary" functions + "between": BetweenOperation(), "and": ReduceOperation(operation=operator.and_), "or": ReduceOperation(operation=operator.or_), ">": ReduceOperation(operation=operator.gt), @@ -873,20 +886,25 @@ def convert( dc: DataContainer, context: "dask_sql.Context", ) -> SeriesOrScalar: - logger.debug(f"Expression Operands: {expr.getOperands()}") + + print(f"\n\nEntering call.py convert for expr: {expr.toString()}") + + for ex in expr.getOperands(): + print(f"convert operand expr: {ex.toString()}") + # Prepare the operands by turning the RexNodes into python expressions operands = [ RexConverter.convert(rel, o, dc, context=context) for o in expr.getOperands() ] - logger.debug(f"Operands: {operands}") + print(f"\nOperands post conversion: {operands}") # Now use the operator name in the mapping # TODO: obviously this needs to not be hardcoded but not sure of the best place to pull the value from currently??? 
schema_name = "root" operator_name = expr.getOperatorName().lower() - logger.debug(f"Operator Name: {operator_name}") + print(f"Operator Name: {operator_name}") try: operation = self.OPERATION_MAPPING[operator_name] diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index af408ac8b..47aa55dbb 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -12,54 +12,54 @@ def test_select(c, df): assert_eq(result_df, df) -# def test_select_alias(c, df): -# result_df = c.sql("SELECT a as b, b as a FROM df") +def test_select_alias(c, df): + result_df = c.sql("SELECT a as b, b as a FROM df") -# expected_df = pd.DataFrame(index=df.index) -# expected_df["b"] = df.a -# expected_df["a"] = df.b + expected_df = pd.DataFrame(index=df.index) + expected_df["b"] = df.a + expected_df["a"] = df.b -# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) + assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) -# def test_select_column(c, df): -# result_df = c.sql("SELECT a FROM df") +def test_select_column(c, df): + result_df = c.sql("SELECT a FROM df") -# assert_eq(result_df, df[["a"]]) + assert_eq(result_df, df[["a"]]) -# def test_select_different_types(c): -# expected_df = pd.DataFrame( -# { -# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), -# "string": ["this is a test", "another test", "äölüć", ""], -# "integer": [1, 2, -4, 5], -# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], -# } -# ) -# c.create_table("df", expected_df) -# result_df = c.sql( -# """ -# SELECT * -# FROM df -# """ -# ) +def test_select_different_types(c): + expected_df = pd.DataFrame( + { + "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), + "string": ["this is a test", "another test", "äölüć", ""], + "integer": [1, 2, -4, 5], + "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], + } + ) + c.create_table("df", expected_df) + result_df = c.sql( + """ + SELECT * + FROM df + """ + ) -# assert_eq(result_df, expected_df) + assert_eq(result_df, expected_df) -# def test_select_expr(c, df): -# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") -# result_df = result_df +def test_select_expr(c, df): + result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") + result_df = result_df -# expected_df = pd.DataFrame( -# { -# "a": df["a"] + 1, -# "bla": df["b"], -# 'df.a - Int64(1)': df["a"] - 1, -# } -# ) -# assert_eq(result_df, expected_df) + expected_df = pd.DataFrame( + { + "a": df["a"] + 1, + "bla": df["b"], + "df.a - Int64(1)": df["a"] - 1, + } + ) + assert_eq(result_df, expected_df) @pytest.mark.skip( @@ -101,22 +101,22 @@ def test_select_of_select_with_casing(c, df): assert_eq(result_df, expected_df) -# def test_wrong_input(c): -# with pytest.raises(ParsingException): -# c.sql("""SELECT x FROM df""") +def test_wrong_input(c): + with pytest.raises(ParsingException): + c.sql("""SELECT x FROM df""") -# with pytest.raises(ParsingException): -# c.sql("""SELECT x FROM df""") + with pytest.raises(ParsingException): + c.sql("""SELECT x FROM df""") -# def test_timezones(c, datetime_table): -# result_df = c.sql( -# """ -# SELECT * FROM datetime_table -# """ -# ) +def test_timezones(c, datetime_table): + result_df = c.sql( + """ + SELECT * FROM datetime_table + """ + ) -# assert_eq(result_df, datetime_table) + assert_eq(result_df, datetime_table) @pytest.mark.skip(reason="WIP DataFusion") @@ -124,7 +124,7 @@ def test_select_of_select_with_casing(c, df): "input_table", [ "long_table", - pytest.param("gpu_long_table", 
marks=pytest.mark.gpu), + # pytest.param("gpu_long_table", marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -142,26 +142,6 @@ def test_limit(c, input_table, limit, offset, request): assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) -<<<<<<< HEAD -@pytest.mark.parametrize( - "input_table", - [ - "datetime_table", - # pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), - ], -) -def test_date_casting(c, input_table, request): - datetime_table = request.getfixturevalue(input_table) - result_df = c.sql( - f""" - SELECT - CAST(timezone AS DATE) AS timezone, - CAST(no_timezone AS DATE) AS no_timezone, - CAST(utc_timezone AS DATE) AS utc_timezone - FROM {input_table} - """ - ) -======= # @pytest.mark.parametrize( # "input_table", # [ @@ -180,23 +160,19 @@ def test_date_casting(c, input_table, request): # FROM {input_table} # """ # ) ->>>>>>> Output Expression as String when BinaryExpr does not contain a named alias -# expected_df = datetime_table -# expected_df["timezone"] = ( -# expected_df["timezone"].astype(" Date: Mon, 2 May 2022 09:37:59 -0400 Subject: [PATCH 47/61] adjust timestamp casting --- dask_planner/src/expression.rs | 6 ++++-- dask_sql/physical/rel/logical/project.py | 4 ---- dask_sql/physical/rex/core/call.py | 10 +--------- tests/integration/test_groupby.py | 1 - 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index b21d1a027..b3f894d58 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -97,7 +97,6 @@ impl PyExpr { } _ => { println!("Encountered a non-Aggregate type"); - name.clone().to_ascii_uppercase() } } @@ -110,7 +109,10 @@ impl PyExpr { } Expr::Column(column) => column.name.clone(), Expr::ScalarVariable(..) => unimplemented!("ScalarVariable!!!"), - Expr::Literal(..) => unimplemented!("Literal!!!"), + Expr::Literal(scalar_value) => { + println!("Scalar Value: {:?}", scalar_value); + unimplemented!("Literal!!!") + } Expr::BinaryExpr { .. 
} => { // If the BinaryExpr does not have an Alias // Ex: `df.a - Int64(1)` then use the String diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index ba12025ab..0441fe486 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -38,10 +38,6 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai new_columns = {} new_mappings = {} - # Debugging only - for key, expr in named_projects: - print(f"Key: {key} - Expr: {expr.toString()}", str(key), expr) - # Collect all (new) columns this Projection will limit to for key, expr in named_projects: diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 85178e49a..fc571eeae 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -235,7 +235,7 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: # TODO: ideally we don't want to directly access the datetimes, # but Pandas can't truncate timezone datetimes and cuDF can't # truncate datetimes - if output_type == "DATE": + if output_type == "DATE" or output_type == "TIMESTAMP": return return_column.dt.floor("D").astype(python_type) return return_column @@ -887,24 +887,16 @@ def convert( context: "dask_sql.Context", ) -> SeriesOrScalar: - print(f"\n\nEntering call.py convert for expr: {expr.toString()}") - - for ex in expr.getOperands(): - print(f"convert operand expr: {ex.toString()}") - # Prepare the operands by turning the RexNodes into python expressions operands = [ RexConverter.convert(rel, o, dc, context=context) for o in expr.getOperands() ] - print(f"\nOperands post conversion: {operands}") - # Now use the operator name in the mapping # TODO: obviously this needs to not be hardcoded but not sure of the best place to pull the value from currently??? 
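# Illustrative aside, not part of the patch: with the cast change above, both
# DATE and TIMESTAMP casts go through dt.floor("D"), which drops the
# time-of-day component while keeping a datetime64 dtype. Sketch on a plain
# pandas Series:
import pandas as pd

ts = pd.Series(pd.to_datetime(["2022-01-21 17:34:00", "2022-01-22 03:15:00"]))
print(ts.dt.floor("D"))  # 2022-01-21 and 2022-01-22, both truncated to midnight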
schema_name = "root" operator_name = expr.getOperatorName().lower() - print(f"Operator Name: {operator_name}") try: operation = self.OPERATION_MAPPING[operator_name] diff --git a/tests/integration/test_groupby.py b/tests/integration/test_groupby.py index 109074692..e69baff8b 100644 --- a/tests/integration/test_groupby.py +++ b/tests/integration/test_groupby.py @@ -22,7 +22,6 @@ def timeseries_df(c): return None -@pytest.mark.skip(reason="WIP DataFusion") def test_group_by(c): return_df = c.sql( """ From 533f50a7daf1a973a019238d621e22defef217c0 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 9 May 2022 15:07:42 -0400 Subject: [PATCH 48/61] Refactor projection _column_name() logic to the _column_name logic in expression.rs --- dask_planner/src/expression.rs | 5 +- dask_planner/src/sql/logical/projection.rs | 56 ++-------------------- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 073507bb8..dd4f2d8b1 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -64,7 +64,8 @@ impl PyExpr { } } - fn _column_name(&self, plan: LogicalPlan) -> Result { + /// Determines the name of the `Expr` instance by examining the LogicalPlan + pub fn _column_name(&self, plan: &LogicalPlan) -> Result { let field = expr_to_field(&self.expr, &plan)?; Ok(field.unqualified_column().name.clone()) } @@ -203,7 +204,7 @@ impl PyExpr { /// Python friendly shim code to get the name of a column referenced by an expression pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> PyResult { - self._column_name(plan.current_node()) + self._column_name(&plan.current_node()) .map_err(|e| py_runtime_err(e)) } diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index 1b273b694..81b327d89 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -33,57 +33,6 @@ impl PyProjection { #[pymethods] impl PyProjection { - #[pyo3(name = "getColumnName")] - fn column_name(&mut self, expr: PyExpr) -> PyResult { - let mut val: String = String::from("OK"); - match expr.expr { - Expr::Alias(expr, name) => match expr.as_ref() { - Expr::Column(col) => { - let index = self.projection.input.schema().index_of_column(col).unwrap(); - match self.projection.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - let mut exprs = agg.group_expr.clone(); - exprs.extend_from_slice(&agg.aggr_expr); - match &exprs[index] { - Expr::AggregateFunction { args, .. 
} => match &args[0] { - Expr::Column(col) => { - println!("AGGREGATE COLUMN IS {}", col.name); - val = col.name.clone(); - } - _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &args[0]), - }, - _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &exprs[index]), - } - } - LogicalPlan::TableScan(table_scan) => val = table_scan.table_name.clone(), - _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input), - } - } - Expr::Cast { expr, data_type: _ } => { - let ex_type: Expr = *expr.clone(); - let py_type: PyExpr = - PyExpr::from(ex_type, Some(self.projection.input.clone())); - val = self.column_name(py_type).unwrap(); - println!("Setting col name to: {:?}", val); - } - _ => val = name.clone().to_ascii_uppercase(), - }, - Expr::Column(col) => val = col.name.clone(), - Expr::Cast { expr, data_type: _ } => { - let ex_type: Expr = *expr; - let py_type: PyExpr = PyExpr::from(ex_type, Some(self.projection.input.clone())); - val = self.column_name(py_type).unwrap() - } - _ => { - panic!( - "column_name is unimplemented for Expr variant: {:?}", - expr.expr - ); - } - } - Ok(val) - } - #[pyo3(name = "getNamedProjects")] fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); @@ -91,8 +40,9 @@ impl PyProjection { let mut py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone())); py_expr.input_plan = Some(self.projection.input.clone()); for expr in self.projected_expressions(&py_expr) { - let name: String = self.column_name(expr.clone()).unwrap(); - named.push((name, expr.clone())); + if let Ok(name) = expr._column_name(&*self.projection.input) { + named.push((name, expr.clone())); + } } } Ok(named) From a42a1332c340d0942d770384643d6074e6ee355a Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 9 May 2022 15:48:37 -0400 Subject: [PATCH 49/61] removed println! 
statements --- dask_planner/src/expression.rs | 46 ++++++++-------------- dask_planner/src/sql/logical/join.rs | 1 - dask_planner/src/sql/logical/projection.rs | 4 -- dask_planner/src/sql/table.rs | 1 - dask_planner/src/sql/types.rs | 5 +-- 5 files changed, 18 insertions(+), 39 deletions(-) diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index dd4f2d8b1..5c725be2f 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -212,16 +212,10 @@ impl PyExpr { #[pyo3(name = "getOperands")] pub fn get_operands(&self) -> PyResult> { match &self.expr { - Expr::BinaryExpr { left, op: _, right } => { - let mut operands: Vec = Vec::new(); - let left_desc: Expr = *left.clone(); - let py_left: PyExpr = PyExpr::from(left_desc, self.input_plan.clone()); - operands.push(py_left); - let right_desc: Expr = *right.clone(); - let py_right: PyExpr = PyExpr::from(right_desc, self.input_plan.clone()); - operands.push(py_right); - Ok(operands) - } + Expr::BinaryExpr { left, op: _, right } => Ok(vec![ + PyExpr::from(*left.clone(), self.input_plan.clone()), + PyExpr::from(*right.clone(), self.input_plan.clone()), + ]), Expr::ScalarFunction { fun: _, args } => { let mut operands: Vec = Vec::new(); for arg in args { @@ -231,11 +225,7 @@ impl PyExpr { Ok(operands) } Expr::Cast { expr, data_type: _ } => { - let mut operands: Vec = Vec::new(); - let ex: Expr = *expr.clone(); - let py_ex: PyExpr = PyExpr::from(ex, self.input_plan.clone()); - operands.push(py_ex); - Ok(operands) + Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())]) } Expr::Case { expr, @@ -243,9 +233,9 @@ impl PyExpr { else_expr, } => { let mut operands: Vec = Vec::new(); - match expr { - Some(e) => operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())), - None => (), + + if let Some(e) = expr { + operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())); }; for (when, then) in when_then_expr { @@ -253,10 +243,10 @@ impl PyExpr { operands.push(PyExpr::from(*then.clone(), self.input_plan.clone())); } - match else_expr { - Some(e) => operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())), - None => (), - } + if let Some(e) = else_expr { + operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())); + }; + Ok(operands) } Expr::Between { @@ -264,13 +254,11 @@ impl PyExpr { negated: _, low, high, - } => { - let mut operands: Vec = Vec::new(); - operands.push(PyExpr::from(*expr.clone(), self.input_plan.clone())); - operands.push(PyExpr::from(*low.clone(), self.input_plan.clone())); - operands.push(PyExpr::from(*high.clone(), self.input_plan.clone())); - Ok(operands) - } + } => Ok(vec![ + PyExpr::from(*expr.clone(), self.input_plan.clone()), + PyExpr::from(*low.clone(), self.input_plan.clone()), + PyExpr::from(*high.clone(), self.input_plan.clone()), + ]), _ => Err(PyErr::new::(format!( "unknown Expr type {:?} encountered", &self.expr diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index ccb77ef6b..fbed464ca 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -27,7 +27,6 @@ impl PyJoin { let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new(); for (mut lhs, mut rhs) in self.join.on.clone() { - println!("lhs: {:?} rhs: {:?}", lhs, rhs); lhs.relation = Some(lhs_table_name.clone()); rhs.relation = Some(rhs_table_name.clone()); join_conditions.push((lhs.into(), rhs.into())); diff --git a/dask_planner/src/sql/logical/projection.rs 
b/dask_planner/src/sql/logical/projection.rs index 81b327d89..bbce9a137 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -14,10 +14,6 @@ pub struct PyProjection { impl PyProjection { /// Projection: Gets the names of the fields that should be projected fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec { - println!("Exprs: {:?}", &self.projection.expr); - println!("Input: {:?}", &self.projection.input); - println!("Schema: {:?}", &self.projection.schema); - println!("Alias: {:?}", &self.projection.alias); let mut projs: Vec = Vec::new(); match &local_expr.expr { Expr::Alias(expr, _name) => { diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index eebc6ff7f..10b1e7ccc 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -125,7 +125,6 @@ impl DaskTable { qualified_name.push(table_scan.table_name); } _ => { - println!("Nothing matches"); qualified_name.push(self.name.clone()); } } diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 618786d88..2765664df 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -49,8 +49,6 @@ impl DaskTypeMap { #[new] #[args(sql_type, py_kwargs = "**")] fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> Self { - // println!("sql_type={:?} - py_kwargs={:?}", sql_type, py_kwargs); - let d_type: DataType = match sql_type { SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { let (unit, tz) = match py_kwargs { @@ -206,8 +204,7 @@ impl SqlTypeName { SqlTypeName::DATE => DataType::Date64, SqlTypeName::VARCHAR => DataType::Utf8, _ => { - println!("Type: {:?}", self); - todo!(); + todo!("Type: {:?}", self); } } } From a1841c35b1294e13dc4122fc081a8dd321499665 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Wed, 11 May 2022 15:12:38 -0400 Subject: [PATCH 50/61] Updates from review --- .gitignore | 3 ++ dask_planner/src/expression.rs | 51 +++++------------------------- dask_sql/physical/rex/core/call.py | 2 +- tests/integration/test_select.py | 14 ++------ 4 files changed, 14 insertions(+), 56 deletions(-) diff --git a/.gitignore b/.gitignore index 950c92821..c25366594 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,6 @@ dask-worker-space/ node_modules/ docs/source/_build/ dask_planner/Cargo.lock + +# Ignore development specific local testing files +dev_tests diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 5c725be2f..de1f56d90 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -70,9 +70,9 @@ impl PyExpr { Ok(field.unqualified_column().name.clone()) } - fn _rex_type(&self, expr: Expr) -> RexType { - match &expr { - Expr::Alias(expr, name) => RexType::Reference, + fn _rex_type(&self, expr: &Expr) -> RexType { + match expr { + Expr::Alias(..) => RexType::Reference, Expr::Column(..) => RexType::Reference, Expr::ScalarVariable(..) => RexType::Literal, Expr::Literal(..) => RexType::Literal, @@ -175,31 +175,8 @@ impl PyExpr { /// Determines the type of this Expr based on its variant #[pyo3(name = "getRexType")] - pub fn rex_type(&self) -> RexType { - match &self.expr { - Expr::Alias(..) => RexType::Reference, - Expr::Column(..) => RexType::Reference, - Expr::ScalarVariable(..) => RexType::Literal, - Expr::Literal(..) => RexType::Literal, - Expr::BinaryExpr { .. } => RexType::Call, - Expr::Not(..) => RexType::Call, - Expr::IsNotNull(..) => RexType::Call, - Expr::Negative(..) => RexType::Call, - Expr::GetIndexedField { .. 
} => RexType::Reference, - Expr::IsNull(..) => RexType::Call, - Expr::Between { .. } => RexType::Call, - Expr::Case { .. } => RexType::Call, - Expr::Cast { .. } => RexType::Call, - Expr::TryCast { .. } => RexType::Call, - Expr::Sort { .. } => RexType::Call, - Expr::ScalarFunction { .. } => RexType::Call, - Expr::AggregateFunction { .. } => RexType::Call, - Expr::WindowFunction { .. } => RexType::Call, - Expr::AggregateUDF { .. } => RexType::Call, - Expr::InList { .. } => RexType::Call, - Expr::Wildcard => RexType::Call, - _ => RexType::Other, - } + pub fn rex_type(&self) -> PyResult { + Ok(self._rex_type(&self.expr)) } /// Python friendly shim code to get the name of a column referenced by an expression @@ -275,21 +252,9 @@ impl PyExpr { right: _, } => Ok(format!("{}", op)), Expr::ScalarFunction { fun, args: _ } => Ok(format!("{}", fun)), - Expr::Cast { - expr: _, - data_type: _, - } => Ok(String::from("cast")), - Expr::Between { - expr: _, - negated: _, - low: _, - high: _, - } => Ok(String::from("between")), - Expr::Case { - expr: _, - when_then_expr: _, - else_expr: _, - } => Ok(String::from("case")), + Expr::Cast { .. } => Ok(String::from("cast")), + Expr::Between { .. } => Ok(String::from("between")), + Expr::Case { .. } => Ok(String::from("case")), _ => Err(PyErr::new::(format!( "Catch all triggered for get_operator_name: {:?}", &self.expr diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index fc571eeae..66c7d410a 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -235,7 +235,7 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: # TODO: ideally we don't want to directly access the datetimes, # but Pandas can't truncate timezone datetimes and cuDF can't # truncate datetimes - if output_type == "DATE" or output_type == "TIMESTAMP": + if output_type == "DATE": return return_column.dt.floor("D").astype(python_type) return return_column diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 59d54947b..b3c73b3bb 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -124,7 +124,7 @@ def test_timezones(c, datetime_table): "input_table", [ "long_table", - # pytest.param("gpu_long_table", marks=pytest.mark.gpu), + pytest.param("gpu_long_table", marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize( @@ -194,17 +194,7 @@ def test_timestamp_casting(c, input_table, request): """ ) - expected_df = datetime_table - expected_df["timezone"] = ( - expected_df["timezone"].astype(" Date: Wed, 11 May 2022 19:15:30 -0400 Subject: [PATCH 51/61] Add Offset and point to repo with offset in datafusion --- dask_planner/Cargo.toml | 4 +- dask_planner/src/expression.rs | 21 +- dask_planner/src/sql.rs | 6 +- dask_planner/src/sql/logical.rs | 11 +- dask_planner/src/sql/logical/aggregate.rs | 4 +- dask_planner/src/sql/logical/filter.rs | 4 +- dask_planner/src/sql/logical/join.rs | 4 +- dask_planner/src/sql/logical/limit.rs | 42 +++ dask_planner/src/sql/logical/projection.rs | 4 +- dask_planner/src/sql/table.rs | 8 +- .../src/sql/types/rel_data_type_field.rs | 2 +- dask_sql/mappings.py | 4 +- dask_sql/physical/rel/logical/limit.py | 10 +- dask_sql/physical/rex/convert.py | 11 +- tests/integration/test_select.py | 320 +++++++++--------- 15 files changed, 257 insertions(+), 198 deletions(-) create mode 100644 dask_planner/src/sql/logical/limit.rs diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 292c13487..2513fb5ba 100644 --- a/dask_planner/Cargo.toml +++ 
b/dask_planner/Cargo.toml @@ -12,11 +12,9 @@ rust-version = "1.59" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "5927bfceeba3a4eab7988289c674d925cc82ac05" } -datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "5927bfceeba3a4eab7988289c674d925cc82ac05" } +datafusion = { git="https://github.com/jdye64/arrow-datafusion/", branch = "limit-offset" } uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } -sqlparser = "0.14.0" parking_lot = "0.12" async-trait = "0.1.41" diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index de1f56d90..76aa19fba 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -4,14 +4,14 @@ use crate::sql::types::RexType; use pyo3::prelude::*; use std::convert::From; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{lit, BuiltinScalarFunction, Expr}; +use datafusion::logical_expr::{lit, BuiltinScalarFunction, Expr}; use datafusion::scalar::ScalarValue; -pub use datafusion_expr::LogicalPlan; +use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::Column; @@ -499,8 +499,15 @@ impl PyExpr { /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> Result { - // TODO this is not the implementation that we really want and will be improved - // once some changes are made in DataFusion - let fields = exprlist_to_fields(&[expr.clone()], &input_plan.schema())?; - Ok(fields[0].clone()) + match expr { + Expr::Sort { expr, .. 
} => { + // DataFusion does not support create_name for sort expressions (since they never + // appear in projections) so we just delegate to the contained expression instead + expr_to_field(expr, input_plan) + } + _ => { + let fields = exprlist_to_fields(&[expr.clone()], &input_plan)?; + Ok(fields[0].clone()) + } + } } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 4dec6f599..4e8c5813e 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -11,12 +11,13 @@ use crate::sql::exceptions::ParsingException; use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::catalog::{ResolvedTableReference, TableReference}; +use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; +use datafusion::logical_expr::ScalarFunctionImplementation; use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::udf::ScalarUDF; use datafusion::sql::parser::DFParser; use datafusion::sql::planner::{ContextProvider, SqlToRel}; -use datafusion_expr::ScalarFunctionImplementation; use std::collections::HashMap; use std::sync::Arc; @@ -55,7 +56,7 @@ impl ContextProvider for DaskSQLContext { fn get_table_provider( &self, name: TableReference, - ) -> Result, DataFusionError> { + ) -> Result, DataFusionError> { let reference: ResolvedTableReference = name.resolve(&self.default_catalog_name, &self.default_schema_name); match self.schemas.get(&self.default_schema_name) { @@ -169,6 +170,7 @@ impl DaskSQLContext { &self, statement: statement::PyStatement, ) -> PyResult { + println!("STATEMENT: {:?}", statement); let planner = SqlToRel::new(self); planner .statement_to_plan(statement.statement) diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index b6a961e67..68a22716b 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -5,9 +5,10 @@ use crate::sql::types::rel_data_type_field::RelDataTypeField; mod aggregate; mod filter; mod join; +mod limit; pub mod projection; -pub use datafusion_expr::LogicalPlan; +use datafusion::logical_expr::LogicalPlan; use datafusion::common::Result; use datafusion::prelude::Column; @@ -71,6 +72,12 @@ impl PyLogicalPlan { Ok(agg) } + /// LogicalPlan::Limit as PyLimit + pub fn limit(&self) -> PyResult { + let limit: limit::PyLimit = self.current_node.clone().unwrap().into(); + Ok(limit) + } + /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult> { let mut py_inputs: Vec = Vec::new(); @@ -116,6 +123,7 @@ impl PyLogicalPlan { LogicalPlan::TableScan(_table_scan) => "TableScan", LogicalPlan::EmptyRelation(_empty_relation) => "EmptyRelation", LogicalPlan::Limit(_limit) => "Limit", + LogicalPlan::Offset(_offset) => "Offset", LogicalPlan::CreateExternalTable(_create_external_table) => "CreateExternalTable", LogicalPlan::CreateMemoryTable(_create_memory_table) => "CreateMemoryTable", LogicalPlan::DropTable(_drop_table) => "DropTable", @@ -127,6 +135,7 @@ impl PyLogicalPlan { LogicalPlan::SubqueryAlias(_sqalias) => "SubqueryAlias", LogicalPlan::CreateCatalogSchema(_create) => "CreateCatalogSchema", LogicalPlan::CreateCatalog(_create_catalog) => "CreateCatalog", + LogicalPlan::CreateView(_create_view) => "CreateView", }) } diff --git a/dask_planner/src/sql/logical/aggregate.rs b/dask_planner/src/sql/logical/aggregate.rs index 726a73552..522328ed0 100644 --- a/dask_planner/src/sql/logical/aggregate.rs +++ b/dask_planner/src/sql/logical/aggregate.rs @@ -1,7 +1,7 @@ use crate::expression::PyExpr; -use 
datafusion_expr::{logical_plan::Aggregate, Expr}; -pub use datafusion_expr::{logical_plan::JoinType, LogicalPlan}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::logical_expr::{logical_plan::Aggregate, Expr}; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/filter.rs b/dask_planner/src/sql/logical/filter.rs index 4474ad1c6..aa8c0774d 100644 --- a/dask_planner/src/sql/logical/filter.rs +++ b/dask_planner/src/sql/logical/filter.rs @@ -1,7 +1,7 @@ use crate::expression::PyExpr; -use datafusion_expr::logical_plan::Filter; -pub use datafusion_expr::LogicalPlan; +use datafusion::logical_expr::logical_plan::Filter; +use datafusion::logical_expr::LogicalPlan; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index fbed464ca..c0de6ae9a 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -1,7 +1,7 @@ use crate::sql::column; -use datafusion_expr::logical_plan::Join; -pub use datafusion_expr::{logical_plan::JoinType, LogicalPlan}; +use datafusion::logical_expr::logical_plan::Join; +use datafusion::logical_expr::{logical_plan::JoinType, LogicalPlan}; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/limit.rs b/dask_planner/src/sql/logical/limit.rs new file mode 100644 index 000000000..c97fd3025 --- /dev/null +++ b/dask_planner/src/sql/logical/limit.rs @@ -0,0 +1,42 @@ +use crate::expression::PyExpr; + +use datafusion::scalar::ScalarValue; +use pyo3::prelude::*; + +use datafusion::logical_expr::logical_plan::Limit; +use datafusion::logical_expr::{Expr, LogicalPlan}; + +#[pyclass(name = "Limit", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyLimit { + limit: Limit, +} + +#[pymethods] +impl PyLimit { + #[pyo3(name = "getOffset")] + pub fn limit_offset(&self) -> PyResult { + // TODO: Waiting on DataFusion issue: https://github.com/apache/arrow-datafusion/issues/2377 + Ok(PyExpr::from( + Expr::Literal(ScalarValue::UInt64(Some(0))), + Some(self.limit.input.clone()), + )) + } + + #[pyo3(name = "getFetch")] + pub fn limit_n(&self) -> PyResult { + Ok(PyExpr::from( + Expr::Literal(ScalarValue::UInt64(Some(self.limit.n.try_into().unwrap()))), + Some(self.limit.input.clone()), + )) + } +} + +impl From for PyLimit { + fn from(logical_plan: LogicalPlan) -> PyLimit { + match logical_plan { + LogicalPlan::Limit(limit) => PyLimit { limit: limit }, + _ => panic!("something went wrong here"), + } + } +} diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index bbce9a137..fd0e91fbf 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -1,7 +1,7 @@ use crate::expression::PyExpr; -pub use datafusion_expr::LogicalPlan; -use datafusion_expr::{logical_plan::Projection, Expr}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::logical_expr::{logical_plan::Projection, Expr}; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 10b1e7ccc..8f04eeb90 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -7,10 +7,10 @@ use crate::sql::types::SqlTypeName; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, SchemaRef}; -pub use datafusion::datasource::TableProvider; +use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::DataFusionError; +use datafusion::logical_expr::{Expr, LogicalPlan, TableSource}; use 
datafusion::physical_plan::{empty::EmptyExec, project_schema, ExecutionPlan}; -use datafusion_expr::{Expr, LogicalPlan, TableSource}; use pyo3::prelude::*; @@ -63,6 +63,10 @@ impl TableProvider for DaskTableProvider { self.source.schema.clone() } + fn table_type(&self) -> TableType { + todo!() + } + async fn scan( &self, projection: &Option>, diff --git a/dask_planner/src/sql/types/rel_data_type_field.rs b/dask_planner/src/sql/types/rel_data_type_field.rs index befee19b8..4889c35a6 100644 --- a/dask_planner/src/sql/types/rel_data_type_field.rs +++ b/dask_planner/src/sql/types/rel_data_type_field.rs @@ -1,7 +1,7 @@ use crate::sql::types::DaskTypeMap; use crate::sql::types::SqlTypeName; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; use datafusion::logical_plan::{DFField, DFSchema}; use std::fmt; diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 47d8624da..3e1c895bf 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -40,7 +40,9 @@ } if FLOAT_NAN_IMPLEMENTED: # pragma: no cover - _PYTHON_TO_SQL.update({pd.Float32Dtype(): "FLOAT", pd.Float64Dtype(): "FLOAT"}) + _PYTHON_TO_SQL.update( + {pd.Float32Dtype(): SqlTypeName.FLOAT, pd.Float64Dtype(): SqlTypeName.DOUBLE} + ) # Default mapping between SQL types and python types # for values diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index 76773e37e..bed8508e2 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -25,13 +25,15 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container - offset = rel.getOffset() + limit = rel.limit() + + offset = limit.getOffset() if offset: - offset = RexConverter.convert(offset, df, context=context) + offset = RexConverter.convert(rel, offset, df, context=context) - end = rel.getFetch() + end = limit.getFetch() if end: - end = RexConverter.convert(end, df, context=context) + end = RexConverter.convert(rel, end, df, context=context) if offset: end += offset diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index bbbeda1db..f5eeabb58 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -13,16 +13,6 @@ logger = logging.getLogger(__name__) - -# _REX_TYPE_TO_PLUGIN = { -# "Alias": "InputRef", -# "Column": "InputRef", -# "BinaryExpr": "RexCall", -# "Literal": "RexLiteral", -# "ScalarFunction": "RexCall", -# "Cast": "RexCall", -# } - _REX_TYPE_TO_PLUGIN = { "RexType.Reference": "InputRef", "RexType.Call": "RexCall", @@ -66,6 +56,7 @@ def convert( using the stored plugins and the dictionary of registered dask tables. 
""" + print(f"convert.py invoked, rex: {rex.toString()}") expr_type = _REX_TYPE_TO_PLUGIN[str(rex.getRexType())] try: diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index b3c73b3bb..e20909c78 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -5,121 +5,119 @@ from dask_sql.utils import ParsingException from tests.utils import assert_eq +# def test_select(c, df): +# result_df = c.sql("SELECT * FROM df") -def test_select(c, df): - result_df = c.sql("SELECT * FROM df") +# assert_eq(result_df, df) - assert_eq(result_df, df) +# def test_select_alias(c, df): +# result_df = c.sql("SELECT a as b, b as a FROM df") + +# expected_df = pd.DataFrame(index=df.index) +# expected_df["b"] = df.a +# expected_df["a"] = df.b -def test_select_alias(c, df): - result_df = c.sql("SELECT a as b, b as a FROM df") +# assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) - expected_df = pd.DataFrame(index=df.index) - expected_df["b"] = df.a - expected_df["a"] = df.b - assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) +# def test_select_column(c, df): +# result_df = c.sql("SELECT a FROM df") +# assert_eq(result_df, df[["a"]]) -def test_select_column(c, df): - result_df = c.sql("SELECT a FROM df") - assert_eq(result_df, df[["a"]]) +# def test_select_different_types(c): +# expected_df = pd.DataFrame( +# { +# "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), +# "string": ["this is a test", "another test", "äölüć", ""], +# "integer": [1, 2, -4, 5], +# "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], +# } +# ) +# c.create_table("df", expected_df) +# result_df = c.sql( +# """ +# SELECT * +# FROM df +# """ +# ) +# assert_eq(result_df, expected_df) -def test_select_different_types(c): - expected_df = pd.DataFrame( - { - "date": pd.to_datetime(["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT]), - "string": ["this is a test", "another test", "äölüć", ""], - "integer": [1, 2, -4, 5], - "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], - } - ) - c.create_table("df", expected_df) - result_df = c.sql( - """ - SELECT * - FROM df - """ - ) - assert_eq(result_df, expected_df) +# def test_select_expr(c, df): +# result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") +# result_df = result_df +# expected_df = pd.DataFrame( +# { +# "a": df["a"] + 1, +# "bla": df["b"], +# "df.a - Int64(1)": df["a"] - 1, +# } +# ) +# assert_eq(result_df, expected_df) -def test_select_expr(c, df): - result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") - result_df = result_df - expected_df = pd.DataFrame( - { - "a": df["a"] + 1, - "bla": df["b"], - "df.a - Int64(1)": df["a"] - 1, - } - ) - assert_eq(result_df, expected_df) +# @pytest.mark.skip( +# reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" +# ) +# def test_select_of_select(c, df): +# result_df = c.sql( +# """ +# SELECT 2*c AS e, d - 1 AS f +# FROM +# ( +# SELECT a - 1 AS c, 2*b AS d +# FROM df +# ) AS "inner" +# """ +# ) +# expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) +# assert_eq(result_df, expected_df) -@pytest.mark.skip( - reason="WIP DataFusion, subquery - https://github.com/apache/arrow-datafusion/issues/2237" -) -def test_select_of_select(c, df): - result_df = c.sql( - """ - SELECT 2*c AS e, d - 1 AS f - FROM - ( - SELECT a - 1 AS c, 2*b AS d - FROM df - ) AS "inner" - """ - ) - - expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) - assert_eq(result_df, expected_df) - - 
-@pytest.mark.skip(reason="WIP DataFusion") -def test_select_of_select_with_casing(c, df): - result_df = c.sql( - """ - SELECT AAA, aaa, aAa - FROM - ( - SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA - FROM df - ) AS "inner" - """ - ) - expected_df = pd.DataFrame( - {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} - ) +# @pytest.mark.skip(reason="WIP DataFusion") +# def test_select_of_select_with_casing(c, df): +# result_df = c.sql( +# """ +# SELECT AAA, aaa, aAa +# FROM +# ( +# SELECT a - 1 AS aAa, 2*b AS aaa, a + b AS AAA +# FROM df +# ) AS "inner" +# """ +# ) - assert_eq(result_df, expected_df) +# expected_df = pd.DataFrame( +# {"AAA": df["a"] + df["b"], "aaa": 2 * df["b"], "aAa": df["a"] - 1} +# ) +# assert_eq(result_df, expected_df) -def test_wrong_input(c): - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") - with pytest.raises(ParsingException): - c.sql("""SELECT x FROM df""") +# def test_wrong_input(c): +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") +# with pytest.raises(ParsingException): +# c.sql("""SELECT x FROM df""") -def test_timezones(c, datetime_table): - result_df = c.sql( - """ - SELECT * FROM datetime_table - """ - ) - assert_eq(result_df, datetime_table) +# def test_timezones(c, datetime_table): +# result_df = c.sql( +# """ +# SELECT * FROM datetime_table +# """ +# ) + +# assert_eq(result_df, datetime_table) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table", [ @@ -127,9 +125,13 @@ def test_timezones(c, datetime_table): pytest.param("gpu_long_table", marks=pytest.mark.gpu), ], ) +# @pytest.mark.parametrize( +# "limit,offset", +# [(100, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], +# ) @pytest.mark.parametrize( "limit,offset", - [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], + [(100, 0)], ) def test_limit(c, input_table, limit, offset, request): long_table = request.getfixturevalue(input_table) @@ -142,73 +144,73 @@ def test_limit(c, input_table, limit, offset, request): assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) -@pytest.mark.parametrize( - "input_table", - [ - "datetime_table", - pytest.param("gpu_datetime_table", marks=pytest.mark.gpu), - ], -) -def test_date_casting(c, input_table, request): - datetime_table = request.getfixturevalue(input_table) - result_df = c.sql( - f""" - SELECT - CAST(timezone AS DATE) AS timezone, - CAST(no_timezone AS DATE) AS no_timezone, - CAST(utc_timezone AS DATE) AS utc_timezone - FROM {input_table} - """ - ) - - expected_df = datetime_table - expected_df["timezone"] = ( - expected_df["timezone"].astype(" Date: Thu, 12 May 2022 09:58:46 -0400 Subject: [PATCH 52/61] Introduce offset --- dask_planner/src/sql/logical/limit.rs | 8 +- dask_sql/context.py | 1 + dask_sql/physical/rel/logical/__init__.py | 2 + dask_sql/physical/rel/logical/offset.py | 108 ++++++++++++++++++++++ 4 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 dask_sql/physical/rel/logical/offset.py diff --git a/dask_planner/src/sql/logical/limit.rs b/dask_planner/src/sql/logical/limit.rs index c97fd3025..61dbfaff1 100644 --- a/dask_planner/src/sql/logical/limit.rs +++ b/dask_planner/src/sql/logical/limit.rs @@ -3,8 +3,10 @@ use crate::expression::PyExpr; use datafusion::scalar::ScalarValue; use pyo3::prelude::*; -use datafusion::logical_expr::logical_plan::Limit; -use datafusion::logical_expr::{Expr, LogicalPlan}; +// use 
datafusion::logical_expr::logical_plan::Limit; +// use datafusion::logical_expr::{Expr, LogicalPlan}; + +use datafusion::logical_expr::{logical_plan::Limit, Expr, LogicalPlan}; #[pyclass(name = "Limit", module = "dask_planner", subclass)] #[derive(Clone)] @@ -36,7 +38,7 @@ impl From for PyLimit { fn from(logical_plan: LogicalPlan) -> PyLimit { match logical_plan { LogicalPlan::Limit(limit) => PyLimit { limit: limit }, - _ => panic!("something went wrong here"), + _ => panic!("something went wrong here!!!????"), } } } diff --git a/dask_sql/context.py b/dask_sql/context.py index ed4b3872a..de5038214 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -98,6 +98,7 @@ def __init__(self, logging_level=logging.INFO): RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskOffsetPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False) diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 4ce1aa365..b476e6b86 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -2,6 +2,7 @@ from .filter import DaskFilterPlugin from .join import DaskJoinPlugin from .limit import DaskLimitPlugin +from .offset import DaskOffsetPlugin from .project import DaskProjectPlugin from .sample import SamplePlugin from .sort import DaskSortPlugin @@ -15,6 +16,7 @@ DaskFilterPlugin, DaskJoinPlugin, DaskLimitPlugin, + DaskOffsetPlugin, DaskProjectPlugin, DaskSortPlugin, DaskTableScanPlugin, diff --git a/dask_sql/physical/rel/logical/offset.py b/dask_sql/physical/rel/logical/offset.py new file mode 100644 index 000000000..d7274394d --- /dev/null +++ b/dask_sql/physical/rel/logical/offset.py @@ -0,0 +1,108 @@ +from typing import TYPE_CHECKING + +import dask.dataframe as dd + +from dask_sql.datacontainer import DataContainer +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.physical.rex import RexConverter +from dask_sql.physical.utils.map import map_on_partition_index + +if TYPE_CHECKING: + import dask_sql + from dask_planner.rust import LogicalPlan + + +class DaskOffsetPlugin(BaseRelPlugin): + """ + Offset is used to modify the effective expression bounds in a larger table + (OFFSET). + """ + + class_name = "Offset" + + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: + (dc,) = self.assert_inputs(rel, 1, context) + df = dc.df + cc = dc.column_container + + limit = rel.limit() + + offset = limit.getOffset() + if offset: + offset = RexConverter.convert(rel, offset, df, context=context) + + end = limit.getFetch() + if end: + end = RexConverter.convert(rel, end, df, context=context) + + if offset: + end += offset + + df = self._apply_offset(df, offset, end) + + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + # No column type has changed, so no need to cast again + return DataContainer(df, cc) + + def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: + """ + Limit the dataframe to the window [offset, end]. + That is unfortunately, not so simple as we do not know how many + items we have in each partition. We have therefore no other way than to + calculate (!!!) 
the sizes of each partition. + + After that, we can create a new dataframe from the old + dataframe by calculating for each partition if and how much + it should be used. + We do this via generating our own dask computation graph as + we need to pass the partition number to the selection + function, which is not possible with normal "map_partitions". + """ + if not offset: + # We do a (hopefully) very quick check: if the first partition + # is already enough, we will just use this + first_partition_length = len(df.partitions[0]) + if first_partition_length >= end: + return df.head(end, compute=False) + + # First, we need to find out which partitions we want to use. + # Therefore we count the total number of entries + partition_borders = df.map_partitions(lambda x: len(x)) + + # Now we let each of the partitions figure out, how much it needs to return + # using these partition borders + # For this, we generate out own dask computation graph (as it does not really + # fit well with one of the already present methods). + + # (a) we define a method to be calculated on each partition + # This method returns the part of the partition, which falls between [offset, fetch] + # Please note that the dask object "partition_borders", will be turned into + # its pandas representation at this point and we can calculate the cumsum + # (which is not possible on the dask object). Recalculating it should not cost + # us much, as we assume the number of partitions is rather small. + def select_from_to(df, partition_index, partition_borders): + partition_borders = partition_borders.cumsum().to_dict() + this_partition_border_left = ( + partition_borders[partition_index - 1] if partition_index > 0 else 0 + ) + this_partition_border_right = partition_borders[partition_index] + + if (end and end < this_partition_border_left) or ( + offset and offset >= this_partition_border_right + ): + return df.iloc[0:0] + + from_index = max(offset - this_partition_border_left, 0) if offset else 0 + to_index = ( + min(end, this_partition_border_right) + if end + else this_partition_border_right + ) - this_partition_border_left + + return df.iloc[from_index:to_index] + + # (b) Now we just need to apply the function on every partition + # We do this via the delayed interface, which seems the easiest one. 
+ return map_on_partition_index( + df, select_from_to, partition_borders, meta=df._meta + ) From b72917b0ea4a931fe1e30304908d53b82873dda8 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 12 May 2022 13:34:12 -0400 Subject: [PATCH 53/61] limit updates --- dask_planner/src/sql.rs | 1 - dask_planner/src/sql/logical.rs | 7 +++ dask_planner/src/sql/logical/limit.rs | 14 +---- dask_planner/src/sql/logical/offset.rs | 42 +++++++++++++ dask_sql/physical/rel/logical/limit.py | 83 ++----------------------- dask_sql/physical/rel/logical/offset.py | 6 +- tests/integration/test_select.py | 40 +++++++----- 7 files changed, 83 insertions(+), 110 deletions(-) create mode 100644 dask_planner/src/sql/logical/offset.rs diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 4e8c5813e..532fb4fef 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -170,7 +170,6 @@ impl DaskSQLContext { &self, statement: statement::PyStatement, ) -> PyResult { - println!("STATEMENT: {:?}", statement); let planner = SqlToRel::new(self); planner .statement_to_plan(statement.statement) diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 68a22716b..f19c7470f 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -6,6 +6,7 @@ mod aggregate; mod filter; mod join; mod limit; +mod offset; pub mod projection; use datafusion::logical_expr::LogicalPlan; @@ -78,6 +79,12 @@ impl PyLogicalPlan { Ok(limit) } + /// LogicalPlan::Offset as PyOffset + pub fn offset(&self) -> PyResult { + let offset: offset::PyOffset = self.current_node.clone().unwrap().into(); + Ok(offset) + } + /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult> { let mut py_inputs: Vec = Vec::new(); diff --git a/dask_planner/src/sql/logical/limit.rs b/dask_planner/src/sql/logical/limit.rs index 61dbfaff1..94f8bfccf 100644 --- a/dask_planner/src/sql/logical/limit.rs +++ b/dask_planner/src/sql/logical/limit.rs @@ -3,9 +3,6 @@ use crate::expression::PyExpr; use datafusion::scalar::ScalarValue; use pyo3::prelude::*; -// use datafusion::logical_expr::logical_plan::Limit; -// use datafusion::logical_expr::{Expr, LogicalPlan}; - use datafusion::logical_expr::{logical_plan::Limit, Expr, LogicalPlan}; #[pyclass(name = "Limit", module = "dask_planner", subclass)] @@ -16,16 +13,7 @@ pub struct PyLimit { #[pymethods] impl PyLimit { - #[pyo3(name = "getOffset")] - pub fn limit_offset(&self) -> PyResult { - // TODO: Waiting on DataFusion issue: https://github.com/apache/arrow-datafusion/issues/2377 - Ok(PyExpr::from( - Expr::Literal(ScalarValue::UInt64(Some(0))), - Some(self.limit.input.clone()), - )) - } - - #[pyo3(name = "getFetch")] + #[pyo3(name = "getLimitN")] pub fn limit_n(&self) -> PyResult { Ok(PyExpr::from( Expr::Literal(ScalarValue::UInt64(Some(self.limit.n.try_into().unwrap()))), diff --git a/dask_planner/src/sql/logical/offset.rs b/dask_planner/src/sql/logical/offset.rs new file mode 100644 index 000000000..ae9aee823 --- /dev/null +++ b/dask_planner/src/sql/logical/offset.rs @@ -0,0 +1,42 @@ +use crate::expression::PyExpr; + +use datafusion::scalar::ScalarValue; +use pyo3::prelude::*; + +use datafusion::logical_expr::{logical_plan::Offset, Expr, LogicalPlan}; + +#[pyclass(name = "Offset", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyOffset { + offset: Offset, +} + +#[pymethods] +impl PyOffset { + #[pyo3(name = "getOffset")] + pub fn offset(&self) -> PyResult { + // TODO: Waiting on DataFusion issue: 
https://github.com/apache/arrow-datafusion/issues/2377 + Ok(PyExpr::from( + Expr::Literal(ScalarValue::UInt64(Some(self.offset.offset as u64))), + Some(self.offset.input.clone()), + )) + } + + #[pyo3(name = "getFetch")] + pub fn offset_fetch(&self) -> PyResult { + // TODO: Still need to implement fetch size! For now get everything from offset on with '0' + Ok(PyExpr::from( + Expr::Literal(ScalarValue::UInt64(Some(0))), + Some(self.offset.input.clone()), + )) + } +} + +impl From for PyOffset { + fn from(logical_plan: LogicalPlan) -> PyOffset { + match logical_plan { + LogicalPlan::Offset(offset) => PyOffset { offset: offset }, + _ => panic!("Issue #501"), + } + } +} diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index bed8508e2..2bcaba3e1 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -1,11 +1,8 @@ from typing import TYPE_CHECKING -import dask.dataframe as dd - from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter -from dask_sql.physical.utils.map import map_on_partition_index if TYPE_CHECKING: import dask_sql @@ -25,84 +22,12 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container - limit = rel.limit() - - offset = limit.getOffset() - if offset: - offset = RexConverter.convert(rel, offset, df, context=context) - - end = limit.getFetch() - if end: - end = RexConverter.convert(rel, end, df, context=context) - - if offset: - end += offset + limit = RexConverter.convert(rel, rel.limit().getLimitN(), df, context=context) - df = self._apply_offset(df, offset, end) + # If an offset was present it would have already been processed at this point. + # Therefore it is always safe to start at 0 when applying the limit + df = df.iloc[:limit] cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) - - def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: - """ - Limit the dataframe to the window [offset, end]. - That is unfortunately, not so simple as we do not know how many - items we have in each partition. We have therefore no other way than to - calculate (!!!) the sizes of each partition. - - After that, we can create a new dataframe from the old - dataframe by calculating for each partition if and how much - it should be used. - We do this via generating our own dask computation graph as - we need to pass the partition number to the selection - function, which is not possible with normal "map_partitions". - """ - if not offset: - # We do a (hopefully) very quick check: if the first partition - # is already enough, we will just use this - first_partition_length = len(df.partitions[0]) - if first_partition_length >= end: - return df.head(end, compute=False) - - # First, we need to find out which partitions we want to use. - # Therefore we count the total number of entries - partition_borders = df.map_partitions(lambda x: len(x)) - - # Now we let each of the partitions figure out, how much it needs to return - # using these partition borders - # For this, we generate out own dask computation graph (as it does not really - # fit well with one of the already present methods). 
- - # (a) we define a method to be calculated on each partition - # This method returns the part of the partition, which falls between [offset, fetch] - # Please note that the dask object "partition_borders", will be turned into - # its pandas representation at this point and we can calculate the cumsum - # (which is not possible on the dask object). Recalculating it should not cost - # us much, as we assume the number of partitions is rather small. - def select_from_to(df, partition_index, partition_borders): - partition_borders = partition_borders.cumsum().to_dict() - this_partition_border_left = ( - partition_borders[partition_index - 1] if partition_index > 0 else 0 - ) - this_partition_border_right = partition_borders[partition_index] - - if (end and end < this_partition_border_left) or ( - offset and offset >= this_partition_border_right - ): - return df.iloc[0:0] - - from_index = max(offset - this_partition_border_left, 0) if offset else 0 - to_index = ( - min(end, this_partition_border_right) - if end - else this_partition_border_right - ) - this_partition_border_left - - return df.iloc[from_index:to_index] - - # (b) Now we just need to apply the function on every partition - # We do this via the delayed interface, which seems the easiest one. - return map_on_partition_index( - df, select_from_to, partition_borders, meta=df._meta - ) diff --git a/dask_sql/physical/rel/logical/offset.py b/dask_sql/physical/rel/logical/offset.py index d7274394d..9cb29818f 100644 --- a/dask_sql/physical/rel/logical/offset.py +++ b/dask_sql/physical/rel/logical/offset.py @@ -25,13 +25,13 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container - limit = rel.limit() + offset_node = rel.offset() - offset = limit.getOffset() + offset = offset_node.getOffset() if offset: offset = RexConverter.convert(rel, offset, df, context=context) - end = limit.getFetch() + end = offset_node.getFetch() if end: end = RexConverter.convert(rel, end, df, context=context) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index e20909c78..be071afeb 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -118,28 +118,40 @@ # assert_eq(result_df, datetime_table) -@pytest.mark.parametrize( - "input_table", - [ - "long_table", - pytest.param("gpu_long_table", marks=pytest.mark.gpu), - ], -) +# @pytest.mark.parametrize( +# "input_table", +# [ +# "long_table", +# pytest.param("gpu_long_table", marks=pytest.mark.gpu), +# ], +# ) +# # @pytest.mark.parametrize( +# # "limit,offset", +# # [(100, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], +# # ) # @pytest.mark.parametrize( # "limit,offset", -# [(100, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], +# [(100, 0)], # ) +# def test_limit(c, input_table, limit, offset, request): +# long_table = request.getfixturevalue(input_table) + +# if not limit: +# query = f"SELECT * FROM long_table OFFSET {offset}" +# else: +# query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + +# assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) + + @pytest.mark.parametrize( "limit,offset", [(100, 0)], ) -def test_limit(c, input_table, limit, offset, request): - long_table = request.getfixturevalue(input_table) +def test_limit(c, limit, offset, request): + long_table = request.getfixturevalue("long_table") - if not limit: - query = f"SELECT * FROM long_table OFFSET {offset}" - else: - query = 
f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + query = f"SELECT * FROM long_table LIMIT {limit}" assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) From 651c9ab9fdfb3eb72bba2407a559ef7d9e8506af Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sun, 15 May 2022 15:50:50 -0400 Subject: [PATCH 54/61] commit before upstream merge --- dask_planner/src/sql.rs | 9 ++-- dask_sql/physical/rel/logical/limit.py | 2 +- dask_sql/physical/rel/logical/offset.py | 15 ++++--- tests/integration/test_select.py | 58 ++++++++++++++----------- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 532fb4fef..7a085d3c4 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -173,9 +173,12 @@ impl DaskSQLContext { let planner = SqlToRel::new(self); planner .statement_to_plan(statement.statement) - .map(|k| logical::PyLogicalPlan { - original_plan: k, - current_node: None, + .map(|k| { + println!("Statement: {:?}", k); + logical::PyLogicalPlan { + original_plan: k, + current_node: None, + } }) .map_err(|e| PyErr::new::(format!("{}", e))) } diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index 2bcaba3e1..04af385d8 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -26,7 +26,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai # If an offset was present it would have already been processed at this point. # Therefore it is always safe to start at 0 when applying the limit - df = df.iloc[:limit] + df = df.head(limit, npartitions=-1, compute=False) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again diff --git a/dask_sql/physical/rel/logical/offset.py b/dask_sql/physical/rel/logical/offset.py index 9cb29818f..e78cc6bfc 100644 --- a/dask_sql/physical/rel/logical/offset.py +++ b/dask_sql/physical/rel/logical/offset.py @@ -31,15 +31,20 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai if offset: offset = RexConverter.convert(rel, offset, df, context=context) - end = offset_node.getFetch() - if end: - end = RexConverter.convert(rel, end, df, context=context) + # TODO: `Fetch` is not currently supported + # end = offset_node.getFetch() + # if end: + # end = RexConverter.convert(rel, end, df, context=context) - if offset: - end += offset + # if offset: + # end += offset + end = df.shape[0].compute() + print(f"End Size: {end} other: {len(df)}") df = self._apply_offset(df, offset, end) + print(f"Size of DF: {df.shape[0].compute()} after applying Offset: {offset}") + cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index be071afeb..edaed7a77 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -118,44 +118,52 @@ # assert_eq(result_df, datetime_table) +@pytest.mark.parametrize( + "input_table", + [ + "long_table", + pytest.param("gpu_long_table", marks=pytest.mark.gpu), + ], +) # @pytest.mark.parametrize( -# "input_table", -# [ -# "long_table", -# pytest.param("gpu_long_table", marks=pytest.mark.gpu), -# ], +# "limit,offset", +# [(100, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], # ) -# # @pytest.mark.parametrize( -# # "limit,offset", -# # [(100, 0), (200, 0), 
(100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], -# # ) # @pytest.mark.parametrize( # "limit,offset", -# [(100, 0)], +# [(100, 99), (100, 100), (101, 101)], # ) -# def test_limit(c, input_table, limit, offset, request): -# long_table = request.getfixturevalue(input_table) - -# if not limit: -# query = f"SELECT * FROM long_table OFFSET {offset}" -# else: -# query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" - -# assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) - - @pytest.mark.parametrize( "limit,offset", - [(100, 0)], + [(100, 99)], ) -def test_limit(c, limit, offset, request): - long_table = request.getfixturevalue("long_table") +def test_limit(c, input_table, limit, offset, request): + long_table = request.getfixturevalue(input_table) - query = f"SELECT * FROM long_table LIMIT {limit}" + print(f"Long_Table: {long_table.shape[0]}") + + if not limit: + query = f"SELECT * FROM long_table OFFSET {offset}" + else: + query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + + print(f"Query: {query}") assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) +# @pytest.mark.parametrize( +# "limit,offset", +# [(100, 0)], +# ) +# def test_limit(c, limit, offset, request): +# long_table = request.getfixturevalue("long_table") + +# query = f"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}" + +# assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None]) + + # @pytest.mark.parametrize( # "input_table", # [ From 3219ad013ede697d8104137ddb12d272793060bb Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 16 May 2022 09:56:19 -0400 Subject: [PATCH 55/61] Code formatting --- dask_planner/src/sql.rs | 9 +++------ dask_planner/src/sql/logical.rs | 22 ++++++++++------------ tests/integration/test_select.py | 2 +- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 4a401f40c..532fb4fef 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -173,12 +173,9 @@ impl DaskSQLContext { let planner = SqlToRel::new(self); planner .statement_to_plan(statement.statement) - .map(|k| { - // println!("Statement: {:?}", k); - logical::PyLogicalPlan { - original_plan: k, - current_node: None, - } + .map(|k| logical::PyLogicalPlan { + original_plan: k, + current_node: None, }) .map_err(|e| PyErr::new::(format!("{}", e))) } diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index d0e175030..0d008fd79 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -80,6 +80,16 @@ impl PyLogicalPlan { to_py_plan(self.current_node.as_ref()) } + /// LogicalPlan::Limit as PyLimit + pub fn limit(&self) -> PyResult { + to_py_plan(self.current_node.as_ref()) + } + + /// LogicalPlan::Offset as PyOffset + pub fn offset(&self) -> PyResult { + to_py_plan(self.current_node.as_ref()) + } + /// LogicalPlan::Projection as PyProjection pub fn projection(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) @@ -95,18 +105,6 @@ impl PyLogicalPlan { )) } - /// LogicalPlan::Limit as PyLimit - pub fn limit(&self) -> PyResult { - let limit: limit::PyLimit = self.current_node.clone().unwrap().into(); - Ok(limit) - } - - /// LogicalPlan::Offset as PyOffset - pub fn offset(&self) -> PyResult { - let offset: offset::PyOffset = self.current_node.clone().unwrap().into(); - Ok(offset) - } - /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult> { let mut 
py_inputs: Vec = Vec::new(); diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index b37f48200..eba5e3608 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -127,7 +127,7 @@ def test_timezones(c, datetime_table): ) @pytest.mark.parametrize( "limit,offset", - [(100, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], + [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)], ) def test_limit(c, input_table, limit, offset, request): long_table = request.getfixturevalue(input_table) From bf91e8ff0b9e089ac9a7e96abd7e1cae161d1428 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 17 May 2022 15:08:32 -0400 Subject: [PATCH 56/61] update Cargo.toml to use Arrow-DataFusion version with LIMIT logic --- dask_planner/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 2513fb5ba..5fa726082 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -12,7 +12,7 @@ rust-version = "1.59" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { git="https://github.com/jdye64/arrow-datafusion/", branch = "limit-offset" } +datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "b1e3a521b7e0ad49c5430bc07bf3026aa2cbb231" } uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } parking_lot = "0.12" From 3dc6a893c49248f4e38f4dfb23dd5acba8f32369 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Wed, 18 May 2022 09:23:38 -0400 Subject: [PATCH 57/61] Bump DataFusion version to get changes around variant_name() --- dask_planner/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 5fa726082..20ad0ab2a 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -12,7 +12,7 @@ rust-version = "1.59" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } rand = "0.7" pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "b1e3a521b7e0ad49c5430bc07bf3026aa2cbb231" } +datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "78207f5092fc5204ecd791278d403dcb6f0ae683" } uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } parking_lot = "0.12" From 08b38aa348fdc3b67db7ca984815d2bd414f2de1 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Thu, 19 May 2022 08:09:48 -0400 Subject: [PATCH 58/61] Use map partitions for determining the offset --- dask_planner/src/sql/logical/offset.rs | 1 - dask_sql/physical/rel/logical/offset.py | 63 +++++++------------------ 2 files changed, 16 insertions(+), 48 deletions(-) diff --git a/dask_planner/src/sql/logical/offset.rs b/dask_planner/src/sql/logical/offset.rs index ae9aee823..c6c9adb63 100644 --- a/dask_planner/src/sql/logical/offset.rs +++ b/dask_planner/src/sql/logical/offset.rs @@ -15,7 +15,6 @@ pub struct PyOffset { impl PyOffset { #[pyo3(name = "getOffset")] pub fn offset(&self) -> PyResult { - // TODO: Waiting on DataFusion issue: https://github.com/apache/arrow-datafusion/issues/2377 Ok(PyExpr::from( Expr::Literal(ScalarValue::UInt64(Some(self.offset.offset as u64))), 
Some(self.offset.input.clone()), diff --git a/dask_sql/physical/rel/logical/offset.py b/dask_sql/physical/rel/logical/offset.py index e78cc6bfc..961060db7 100644 --- a/dask_sql/physical/rel/logical/offset.py +++ b/dask_sql/physical/rel/logical/offset.py @@ -5,7 +5,6 @@ from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter -from dask_sql.physical.utils.map import map_on_partition_index if TYPE_CHECKING: import dask_sql @@ -31,20 +30,9 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai if offset: offset = RexConverter.convert(rel, offset, df, context=context) - # TODO: `Fetch` is not currently supported - # end = offset_node.getFetch() - # if end: - # end = RexConverter.convert(rel, end, df, context=context) - - # if offset: - # end += offset end = df.shape[0].compute() - print(f"End Size: {end} other: {len(df)}") - df = self._apply_offset(df, offset, end) - print(f"Size of DF: {df.shape[0].compute()} after applying Offset: {offset}") - cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) @@ -52,41 +40,23 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: """ Limit the dataframe to the window [offset, end]. - That is unfortunately, not so simple as we do not know how many - items we have in each partition. We have therefore no other way than to - calculate (!!!) the sizes of each partition. - - After that, we can create a new dataframe from the old - dataframe by calculating for each partition if and how much - it should be used. - We do this via generating our own dask computation graph as - we need to pass the partition number to the selection - function, which is not possible with normal "map_partitions". + + Unfortunately, Dask does not currently support row selection through `iloc`, so this must be done using a custom partition function. + However, it is sometimes possible to compute this window using `head` when an `offset` is not specified. """ - if not offset: - # We do a (hopefully) very quick check: if the first partition - # is already enough, we will just use this - first_partition_length = len(df.partitions[0]) - if first_partition_length >= end: - return df.head(end, compute=False) - - # First, we need to find out which partitions we want to use. - # Therefore we count the total number of entries + # compute the size of each partition + # TODO: compute `cumsum` here when dask#9067 is resolved partition_borders = df.map_partitions(lambda x: len(x)) - # Now we let each of the partitions figure out, how much it needs to return - # using these partition borders - # For this, we generate out own dask computation graph (as it does not really - # fit well with one of the already present methods). - - # (a) we define a method to be calculated on each partition - # This method returns the part of the partition, which falls between [offset, fetch] - # Please note that the dask object "partition_borders", will be turned into - # its pandas representation at this point and we can calculate the cumsum - # (which is not possible on the dask object). Recalculating it should not cost - # us much, as we assume the number of partitions is rather small. 
- def select_from_to(df, partition_index, partition_borders): + def limit_partition_func(df, partition_borders, partition_info=None): + """Limit the partition to values contained within the specified window, returning an empty dataframe if there are none""" + + # TODO: remove the `cumsum` call here when dask#9067 is resolved partition_borders = partition_borders.cumsum().to_dict() + partition_index = ( + partition_info["number"] if partition_info is not None else 0 + ) + this_partition_border_left = ( partition_borders[partition_index - 1] if partition_index > 0 else 0 ) @@ -106,8 +76,7 @@ def select_from_to(df, partition_index, partition_borders): return df.iloc[from_index:to_index] - # (b) Now we just need to apply the function on every partition - # We do this via the delayed interface, which seems the easiest one. - return map_on_partition_index( - df, select_from_to, partition_borders, meta=df._meta + return df.map_partitions( + limit_partition_func, + partition_borders=partition_borders, ) From e1290680ba9917fffca366f3540657aa2e2a9174 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 23 May 2022 06:47:33 -0700 Subject: [PATCH 59/61] Refactor offset partition func --- dask_sql/physical/rel/logical/offset.py | 34 +++++++++---------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/dask_sql/physical/rel/logical/offset.py b/dask_sql/physical/rel/logical/offset.py index 961060db7..1eda6a71c 100644 --- a/dask_sql/physical/rel/logical/offset.py +++ b/dask_sql/physical/rel/logical/offset.py @@ -24,20 +24,17 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai df = dc.df cc = dc.column_container - offset_node = rel.offset() - - offset = offset_node.getOffset() - if offset: - offset = RexConverter.convert(rel, offset, df, context=context) + offset = RexConverter.convert( + rel, rel.offset().getOffset(), df, context=context + ) - end = df.shape[0].compute() - df = self._apply_offset(df, offset, end) + df = self._apply_offset(df, offset) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) - def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: + def _apply_offset(self, df: dd.DataFrame, offset: int) -> dd.DataFrame: """ Limit the dataframe to the window [offset, end]. 
@@ -48,7 +45,7 @@ def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame # TODO: compute `cumsum` here when dask#9067 is resolved partition_borders = df.map_partitions(lambda x: len(x)) - def limit_partition_func(df, partition_borders, partition_info=None): + def offset_partition_func(df, partition_borders, partition_info=None): """Limit the partition to values contained within the specified window, returning an empty dataframe if there are none""" # TODO: remove the `cumsum` call here when dask#9067 is resolved @@ -57,26 +54,19 @@ def limit_partition_func(df, partition_borders, partition_info=None): partition_info["number"] if partition_info is not None else 0 ) - this_partition_border_left = ( + partition_border_left = ( partition_borders[partition_index - 1] if partition_index > 0 else 0 ) - this_partition_border_right = partition_borders[partition_index] + partition_border_right = partition_borders[partition_index] - if (end and end < this_partition_border_left) or ( - offset and offset >= this_partition_border_right - ): + if offset >= partition_border_right: return df.iloc[0:0] - from_index = max(offset - this_partition_border_left, 0) if offset else 0 - to_index = ( - min(end, this_partition_border_right) - if end - else this_partition_border_right - ) - this_partition_border_left + from_index = max(offset - partition_border_left, 0) - return df.iloc[from_index:to_index] + return df.iloc[from_index:] return df.map_partitions( - limit_partition_func, + offset_partition_func, partition_borders=partition_borders, ) From 2d11de5798f95d0260a63866f6de1c0010fa140c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 23 May 2022 10:17:48 -0400 Subject: [PATCH 60/61] Update to use TryFrom logic --- dask_planner/src/sql/logical/limit.rs | 11 +++++++---- dask_planner/src/sql/logical/offset.rs | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/dask_planner/src/sql/logical/limit.rs b/dask_planner/src/sql/logical/limit.rs index 94f8bfccf..a57ba24b1 100644 --- a/dask_planner/src/sql/logical/limit.rs +++ b/dask_planner/src/sql/logical/limit.rs @@ -1,4 +1,5 @@ use crate::expression::PyExpr; +use crate::sql::exceptions::py_type_err; use datafusion::scalar::ScalarValue; use pyo3::prelude::*; @@ -22,11 +23,13 @@ impl PyLimit { } } -impl From for PyLimit { - fn from(logical_plan: LogicalPlan) -> PyLimit { +impl TryFrom for PyLimit { + type Error = PyErr; + + fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { - LogicalPlan::Limit(limit) => PyLimit { limit: limit }, - _ => panic!("something went wrong here!!!????"), + LogicalPlan::Limit(limit) => Ok(PyLimit { limit: limit }), + _ => Err(py_type_err("unexpected plan")), } } } diff --git a/dask_planner/src/sql/logical/offset.rs b/dask_planner/src/sql/logical/offset.rs index c6c9adb63..a6074dc81 100644 --- a/dask_planner/src/sql/logical/offset.rs +++ b/dask_planner/src/sql/logical/offset.rs @@ -1,4 +1,5 @@ use crate::expression::PyExpr; +use crate::sql::exceptions::py_type_err; use datafusion::scalar::ScalarValue; use pyo3::prelude::*; @@ -31,11 +32,13 @@ impl PyOffset { } } -impl From for PyOffset { - fn from(logical_plan: LogicalPlan) -> PyOffset { +impl TryFrom for PyOffset { + type Error = PyErr; + + fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { - LogicalPlan::Offset(offset) => PyOffset { offset: offset }, - _ => panic!("Issue #501"), + LogicalPlan::Offset(offset) => Ok(PyOffset { offset: offset }), + _ => Err(py_type_err("unexpected plan")), } } } From 
c9933777cf94cd88fe578d56aef36082cbf8899f Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 23 May 2022 08:53:06 -0700 Subject: [PATCH 61/61] Add cloudpickle to independent scheduler requirements --- .github/docker-compose.yaml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/docker-compose.yaml b/.github/docker-compose.yaml index 01cd9e3ec..21997b505 100644 --- a/.github/docker-compose.yaml +++ b/.github/docker-compose.yaml @@ -3,15 +3,19 @@ version: '3' services: dask-scheduler: container_name: dask-scheduler - image: daskdev/dask:latest + image: daskdev/dask:dev command: dask-scheduler + environment: + USE_MAMBA: "true" + EXTRA_CONDA_PACKAGES: "cloudpickle>=1.5.0" # match client cloudpickle version ports: - "8786:8786" dask-worker: container_name: dask-worker - image: daskdev/dask:latest + image: daskdev/dask:dev command: dask-worker dask-scheduler:8786 environment: + USE_MAMBA: "true" EXTRA_CONDA_PACKAGES: "pyarrow>=4.0.0" # required for parquet IO volumes: - /tmp:/tmp
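
Taken together, patches 52-59 above split the old combined LIMIT/OFFSET handling in two: `DaskLimitPlugin` reduces to `df.head(limit, npartitions=-1, compute=False)`, while `DaskOffsetPlugin` trims each partition against the cumulative partition sizes via `map_partitions` (as the docstring in patch 59 notes, Dask has no global row-based `iloc`, so the slice has to be done per partition). Below is a minimal standalone sketch of that offset technique, with a couple of simplifying assumptions: the `apply_offset` helper and the sample frame are invented for illustration, and the partition sizes are computed eagerly here, whereas the actual plugin keeps them lazy and defers the `cumsum` into the partition function (see the dask#9067 TODO in the diff).

# Sketch only: eager partition sizes, illustrative names; not the plugin code itself.
from itertools import accumulate

import dask.dataframe as dd
import pandas as pd


def apply_offset(df: dd.DataFrame, offset: int) -> dd.DataFrame:
    """Drop the first `offset` rows of `df`, spread across its partitions."""
    # Rows per partition, then cumulative row counts (the "partition borders").
    sizes = df.map_partitions(lambda x: len(x)).compute().tolist()
    borders = list(accumulate(sizes))

    def offset_partition(part, partition_info=None):
        i = partition_info["number"] if partition_info is not None else 0
        left = borders[i - 1] if i > 0 else 0  # global index of this partition's first row
        right = borders[i]                     # global index just past its last row
        if offset >= right:
            # The whole partition lies before the offset, so it contributes nothing.
            return part.iloc[0:0]
        # Keep only the rows at or beyond the global offset.
        return part.iloc[max(offset - left, 0):]

    return df.map_partitions(offset_partition, meta=df._meta)


if __name__ == "__main__":
    ddf = dd.from_pandas(pd.DataFrame({"a": range(10)}), npartitions=3)
    print(apply_offset(ddf, 4).compute())  # rows 4..9, as with `... OFFSET 4`

Handled this way, an `OFFSET` touches every partition exactly once and never repartitions or concatenates the frame, which is the same trade-off the removed `_apply_offset` in limit.py was making before the logic moved into the dedicated offset plugin.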