From 9bb4b3a86d07cae3189d17aa909511679ca8f104 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 2 Aug 2024 09:55:29 -0500 Subject: [PATCH] [FEAT]: sql case/when (#2591) --- src/daft-sql/src/analyzer.rs | 206 ----------------------------------- src/daft-sql/src/lib.rs | 1 + src/daft-sql/src/planner.rs | 30 ++++- tests/sql/test_sql.py | 32 ++++++ 4 files changed, 62 insertions(+), 207 deletions(-) delete mode 100644 src/daft-sql/src/analyzer.rs diff --git a/src/daft-sql/src/analyzer.rs b/src/daft-sql/src/analyzer.rs deleted file mode 100644 index 28f44ce395..0000000000 --- a/src/daft-sql/src/analyzer.rs +++ /dev/null @@ -1,206 +0,0 @@ -use common_error::DaftError; -use common_error::DaftResult; -use daft_plan::LogicalPlanBuilder; -use sqlparser::ast::*; - -use crate::catalog::Catalog; - -#[macro_export] -macro_rules! not_supported { - ($($arg:tt)*) => {{ - let msg = format!($($arg)*); - Err(DaftError::InternalError(msg.to_string())) - }} -} - -/// Context will hold common table expressions (CTEs) and other scoped context information. -pub struct Context {} - -/// The `Analyzer` is responsible for, -/// -/// 1. Normalization -/// 2. Semantic analysis -/// 3. Logical transformation -pub struct Analyzer { - catalog: Catalog, -} - -impl Analyzer { - pub fn new(catalog: Catalog) -> Self { - Analyzer { catalog } - } - - pub fn analyze(&mut self, statement: Statement) -> DaftResult { - if let Statement::Query(query) = statement { - self.analyze_query(*query) - } else { - not_supported!("Statement not supported: {:?}", statement) - } - } - - fn analyze_query(&mut self, query: Query) -> DaftResult { - // initialize context for this query - let context = Context {}; - - // add CTEs to context - if let Some(with) = query.with { - self.analyze_with(with, &context)?; - } - - // analyze the query body - let builder = match *(query.body) { - SetExpr::Select(select) => self.analyze_select(*select, context)?, - _ => return not_supported!("Query not supported: {:?}", query.body), - }; - - // apply limit and offset - let builder = self.analyze_fetch(builder, query.limit, query.offset)?; - - // return a query action with the prepared builder - Ok(builder) - } - - /// TODO push all defined tables into analyzer state. - fn analyze_with(&mut self, _with: With, _context: &Context) -> DaftResult<()> { - not_supported!("WITH not supported") - } - - /// The SELECT-FROM-WHERE AST node. - /// - /// References - /// - https://docs.rs/sqlparser/latest/sqlparser/ast/struct.Select.html - /// - https://github.com/apache/datafusion/blob/main/datafusion/sql/src/select.rs#L50 - /// - fn analyze_select( - &mut self, - select: Select, - _context: Context, - ) -> DaftResult { - // UNSUPPORTED FEATURES - if !select.cluster_by.is_empty() { - return not_supported!("CLUSTER BY"); - } - if !select.lateral_views.is_empty() { - return not_supported!("LATERAL VIEWS"); - } - if select.qualify.is_some() { - return not_supported!("QUALIFY"); - } - if select.top.is_some() { - return not_supported!("TOP"); - } - if !select.sort_by.is_empty() { - return not_supported!("SORT BY"); - } - - // FROM - let builder = self.analyze_tables(select.from, _context)?; - - // WHERE - if select.selection.is_some() { - return not_supported!("WHERE clause not supported"); - } - - // SELECT - self.analyze_select_list(builder, select.projection) - } - - /// Produce a relation from the tables. - fn analyze_tables( - &self, - mut tables: Vec, - _context: Context, - ) -> DaftResult { - match tables.len() { - 0 => not_supported!("SELECT without FROM not supported."), - 1 => self.analyze_table(tables.remove(0), _context), - _ => not_supported!("CROSS JOIN not supported"), - } - } - - fn analyze_select_list( - &mut self, - builder: LogicalPlanBuilder, - select: Vec, - ) -> DaftResult { - let mut had_star = false; - for item in select { - if had_star { - // Consider EXCLUDE support - return not_supported!("Multiple * in SELECT"); - } - match item { - SelectItem::Wildcard(_) => { - had_star = true; - } - SelectItem::UnnamedExpr(_) => { - return not_supported!("SELECT not supported"); - } - SelectItem::ExprWithAlias { expr: _, alias: _ } => { - return not_supported!("SELECT AS not supported"); - } - SelectItem::QualifiedWildcard(_, _) => { - return not_supported!("SELECT .* not supported"); - } - } - } - if had_star { - Ok(builder) - } else { - not_supported!("SELECT not supported, only SELECT *") - } - } - - fn analyze_table( - &self, - table: TableWithJoins, - _context: Context, - ) -> DaftResult { - if !table.joins.is_empty() { - return not_supported!("JOIN is not supported"); - } - self.analyze_relation(table.relation, _context) - } - - fn analyze_relation( - &self, - relation: TableFactor, - _context: Context, - ) -> DaftResult { - match relation { - TableFactor::Table { name, alias, .. } => { - if alias.is_some() { - return not_supported!("
AS alias not supported"); - } - let name = name.0.first().unwrap().value.to_string(); - let plan = self.catalog.get_table(&name); - let plan = match plan { - Some(plan) => LogicalPlanBuilder::new(plan), - None => return not_supported!("Table not found: {}", name), - }; - Ok(plan) - } - TableFactor::Derived { .. } => not_supported!("Derived table"), - TableFactor::TableFunction { .. } => not_supported!("Table function"), - TableFactor::Function { .. } => not_supported!("Function"), - TableFactor::UNNEST { .. } => not_supported!("UNNEST"), - TableFactor::JsonTable { .. } => not_supported!("JsonTable"), - TableFactor::NestedJoin { .. } => not_supported!("NestedJoin"), - TableFactor::Pivot { .. } => not_supported!("Pivot"), - TableFactor::Unpivot { .. } => not_supported!("Unpivot"), - TableFactor::MatchRecognize { .. } => not_supported!("MatchRecognize"), - } - } - - fn analyze_fetch( - &mut self, - builder: LogicalPlanBuilder, - limit: Option, - _offset: Option, - ) -> DaftResult { - if limit.is_some() { - return not_supported!("LIMIT is not supported"); - } - Ok(builder) - } -} diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 0cd959232f..bf9cdbb22a 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -147,6 +147,7 @@ mod tests { #[case::orderby("select * from tbl1 order by i32 desc")] #[case::orderby("select * from tbl1 order by i32 asc")] #[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")] + #[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")] fn test_compiles(#[case] query: &str) -> SQLPlannerResult<()> { let planner = setup(); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 1a94c98c0a..76a7a22343 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -574,7 +574,35 @@ impl SQLPlanner { SQLExpr::TypedString { .. } => unsupported_sql_err!("TYPED STRING"), SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"), SQLExpr::Function(func) => self.plan_function(func, current_relation), - SQLExpr::Case { .. } => unsupported_sql_err!("CASE"), + SQLExpr::Case { + operand, + conditions, + results, + else_result, + } => { + if operand.is_some() { + unsupported_sql_err!("CASE with operand not yet supported"); + } + if results.len() != conditions.len() { + unsupported_sql_err!("CASE with different number of conditions and results"); + } + + let else_expr = match else_result { + Some(expr) => self.plan_expr(expr, current_relation)?, + None => unsupported_sql_err!("CASE with no else result"), + }; + + // we need to traverse from back to front to build the if else chain + // because we need to start with the else expression + conditions.iter().zip(results.iter()).rev().try_fold( + else_expr, + |else_expr, (condition, result)| { + let cond = self.plan_expr(condition, current_relation)?; + let res = self.plan_expr(result, current_relation)?; + Ok(cond.if_else(res, else_expr)) + }, + ) + } SQLExpr::Exists { .. } => unsupported_sql_err!("EXISTS"), SQLExpr::Subquery(_) => unsupported_sql_err!("SUBQUERY"), SQLExpr::GroupingSets(_) => unsupported_sql_err!("GROUPING SETS"), diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index c321fdf7e4..eb04908349 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest import daft @@ -48,3 +49,34 @@ def test_parse_ok(name, sql): print(name) print(sql) print("--------------") + + +def test_fizzbuzz_sql(): + arr = np.arange(100) + df = daft.from_pydict({"a": arr}) + catalog = SQLCatalog({"test": df}) + # test case expression + expected = daft.from_pydict( + { + "a": arr, + "fizzbuzz": [ + "FizzBuzz" if x % 15 == 0 else "Fizz" if x % 3 == 0 else "Buzz" if x % 5 == 0 else str(x) + for x in range(0, 100) + ], + } + ).collect() + df = daft.sql( + """ + SELECT + a, + CASE + WHEN a % 15 = 0 THEN 'FizzBuzz' + WHEN a % 3 = 0 THEN 'Fizz' + WHEN a % 5 = 0 THEN 'Buzz' + ELSE CAST(a AS TEXT) + END AS fizzbuzz + FROM test + """, + catalog=catalog, + ).collect() + assert df.to_pydict() == expected.to_pydict()