Skip to content

Commit

Permalink
improve error reporting for multistatement sql
Browse files Browse the repository at this point in the history
  • Loading branch information
amitschang committed Sep 23, 2024
1 parent 29be743 commit c398ff5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, GroupByExpr, Ident, Query,
SelectItem, StructField, Subscript, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions,
SelectItem, Statement, StructField, Subscript, TableWithJoins, TimezoneInfo, UnaryOperator,
Value, WildcardAdditionalOptions,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -88,9 +88,18 @@ impl SQLPlanner {

let statements = parser.parse_statements()?;

match statements.as_slice() {
[sqlparser::ast::Statement::Query(query)] => Ok(self.plan_query(query)?.build()),
other => unsupported_sql_err!("{}", other[0]),
match statements.len() {
1 => Ok(self.plan_statement(&statements[0])?),
other => {
unsupported_sql_err!("Only exactly one SQL statement allowed, found {}", other)
}
}
}

fn plan_statement(&mut self, statement: &Statement) -> SQLPlannerResult<LogicalPlanRef> {
match statement {
Statement::Query(query) => Ok(self.plan_query(query)?.build()),
other => unsupported_sql_err!("{}", other),
}
}

Expand Down
6 changes: 6 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,9 @@ def test_sql_count_star():
actual = df2.collect().to_pydict()
expected = df.agg(daft.col("b").count()).collect().to_pydict()
assert actual == expected


def test_sql_multi_statement_sql_error():
catalog = SQLCatalog({})
with pytest.raises(Exception, match="one SQL statement allowed"):
daft.sql("SELECT * FROM df; SELECT * FROM df", catalog)

0 comments on commit c398ff5

Please sign in to comment.