diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index bb9e60b2a..7ba4e743a 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -134,7 +134,7 @@ impl PyExpr { Expr::Case { .. } => panic!("Case!!!"), Expr::Cast { .. } => "Cast", Expr::TryCast { .. } => panic!("TryCast!!!"), - Expr::Sort { .. } => panic!("Sort!!!"), + Expr::Sort { .. } => "Sort", Expr::ScalarFunction { .. } => "ScalarFunction", Expr::AggregateFunction { .. } => "AggregateFunction", Expr::WindowFunction { .. } => panic!("WindowFunction!!!"), diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 7be8ee0d9..b6c112c9b 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -8,6 +8,7 @@ mod explain; mod filter; mod join; pub mod projection; +mod sort; pub use datafusion::logical_expr::LogicalPlan; @@ -100,6 +101,16 @@ impl PyLogicalPlan { )) } + /// LogicalPlan::Sort as PySort + pub fn sort(&self) -> PyResult<PySort> { + self.current_node + .as_ref() + .map(|plan| plan.clone().into()) + .ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>( + "current_node was None", + )) + } + /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> { let mut py_inputs: Vec<PyLogicalPlan> = Vec::new(); diff --git a/dask_planner/src/sql/logical/sort.rs b/dask_planner/src/sql/logical/sort.rs new file mode 100644 index 000000000..59292b4a3 --- /dev/null +++ b/dask_planner/src/sql/logical/sort.rs @@ -0,0 +1,71 @@ +use crate::expression::PyExpr; + +use datafusion::logical_expr::{logical_plan::Sort, Expr, LogicalPlan}; +use pyo3::prelude::*; + +#[pyclass(name = "Sort", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PySort { + sort: Sort, +} + +impl PySort { + /// Returns if a sort expression denotes an ascending sort + fn is_ascending(&self, expr: &Expr) -> Result<bool, PyErr> { + match expr { + Expr::Sort { asc, .. 
} => Ok(asc.clone()), + _ => Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!( + "Provided Expr {:?} is not a sort type", + expr + ))), + } + } + /// Returns if nulls should be placed first in a sort expression + fn is_nulls_first(&self, expr: &Expr) -> Result<bool, PyErr> { + match &expr { + Expr::Sort { nulls_first, .. } => Ok(nulls_first.clone()), + _ => Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!( + "Provided Expr {:?} is not a sort type", + expr + ))), + } + } +} +#[pymethods] +impl PySort { + /// Returns a Vec of the sort expressions + #[pyo3(name = "getCollation")] + pub fn sort_expressions(&self) -> PyResult<Vec<PyExpr>> { + let mut sort_exprs: Vec<PyExpr> = Vec::new(); + for expr in &self.sort.expr { + sort_exprs.push(PyExpr::from(expr.clone(), Some(self.sort.input.clone()))); + } + Ok(sort_exprs) + } + + #[pyo3(name = "getAscending")] + pub fn get_ascending(&self) -> PyResult<Vec<bool>> { + self.sort + .expr + .iter() + .map(|sortexpr| self.is_ascending(sortexpr)) + .collect::<Result<Vec<_>, _>>() + } + #[pyo3(name = "getNullsFirst")] + pub fn get_nulls_first(&self) -> PyResult<Vec<bool>> { + self.sort + .expr + .iter() + .map(|sortexpr| self.is_nulls_first(sortexpr)) + .collect::<Result<Vec<_>, _>>() + } +} + +impl From<LogicalPlan> for PySort { + fn from(logical_plan: LogicalPlan) -> PySort { + match logical_plan { + LogicalPlan::Sort(srt) => PySort { sort: srt }, + _ => panic!("something went wrong here"), + } + } +} diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py index c2e79c5e7..e31900ef9 100644 --- a/dask_sql/physical/rel/logical/sort.py +++ b/dask_sql/physical/rel/logical/sort.py @@ -2,8 +2,7 @@ from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin - -# from dask_sql.physical.utils.sort import apply_sort +from dask_sql.physical.utils.sort import apply_sort if TYPE_CHECKING: import dask_sql @@ -21,21 +20,16 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container - - # TODO: Commented out 
to pass flake8, will be fixed in sort PR - # sort_collation = rel.getCollation().getFieldCollations() - # sort_columns = [ - # cc.get_backend_by_frontend_index(int(x.getFieldIndex())) - # for x in sort_collation - # ] - - # ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING - # FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST - # sort_ascending = [x.getDirection() == ASCENDING for x in sort_collation] - # sort_null_first = [x.nullDirection == FIRST for x in sort_collation] + sort_expressions = rel.sort().getCollation() + sort_columns = [ + cc.get_backend_by_frontend_name(expr.column_name(rel)) + for expr in sort_expressions + ] + sort_ascending = rel.sort().getAscending() + sort_null_first = rel.sort().getNullsFirst() df = df.persist() - # df = apply_sort(df, sort_columns, sort_ascending, sort_null_first) + df = apply_sort(df, sort_columns, sort_ascending, sort_null_first) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 3e5c71f3d..0835336ee 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -6,7 +6,6 @@ from tests.utils import assert_eq -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table_1,input_df", [ @@ -67,7 +66,6 @@ def test_sort(c, input_table_1, input_df, request): assert_eq(df_result, df_expected, check_index=False) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize( "input_table_1", ["user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu)], @@ -90,7 +88,6 @@ def test_sort_by_alias(c, input_table_1, request): assert_eq(df_result, df_expected, check_index=False) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sort_with_nan(gpu): c = Context() @@ -181,7 +178,6 @@ def test_sort_with_nan(gpu): ) 
-@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sort_with_nan_more_columns(gpu): c = Context() @@ -240,7 +236,6 @@ ) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sort_with_nan_many_partitions(gpu): c = Context() @@ -281,7 +276,6 @@ ) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sort_strings(c, gpu): string_table = pd.DataFrame({"a": ["zzhsd", "öfjdf", "baba"]}) @@ -301,7 +295,6 @@ assert_eq(df_result, df_expected, check_index=False) -@pytest.mark.skip(reason="WIP DataFusion") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sort_not_allowed(c, gpu): table_name = "gpu_user_table_1" if gpu else "user_table_1" @@ -309,3 +302,53 @@ # Wrong column with pytest.raises(Exception): c.sql(f"SELECT * FROM {table_name} ORDER BY 42") + + +@pytest.mark.xfail(reason="Projection step before sort currently failing") +@pytest.mark.parametrize( + "input_table_1", + ["user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu)], +) +def test_sort_by_old_alias(c, input_table_1, request): + user_table_1 = request.getfixturevalue(input_table_1) + + df_result = c.sql( + f""" + SELECT + b AS my_column + FROM {input_table_1} + ORDER BY b, user_id DESC + """ + ).rename(columns={"my_column": "b"}) + df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[ + ["b"] + ] + + assert_eq(df_result, df_expected, check_index=False) + + df_result = c.sql( + f""" + SELECT + b*-1 AS my_column + FROM {input_table_1} + ORDER BY b, user_id DESC + """ + ).rename(columns={"my_column": "b"}) + df_expected = 
 user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[ + ["b"] + ] + df_expected["b"] *= -1 + assert_eq(df_result, df_expected, check_index=False) + + df_result = c.sql( + f""" + SELECT + b*-1 AS my_column + FROM {input_table_1} + ORDER BY my_column, user_id DESC + """ + ).rename(columns={"my_column": "b"}) + df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[ + ["b"] + ] + df_expected["b"] *= -1