Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions dask_planner/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct CreateModel {
/// model name
pub name: String,
/// input Query
pub select: SQLStatement,
pub select: DaskStatement,
/// IF NOT EXISTS
pub if_not_exists: bool,
/// To replace the model or not
Expand All @@ -41,7 +41,7 @@ pub struct PredictModel {
/// model name
pub name: String,
/// input Query
pub select: SQLStatement,
pub select: DaskStatement,
}

/// Dask-SQL extension DDL for `CREATE SCHEMA`
Expand Down Expand Up @@ -648,7 +648,17 @@ impl<'a> DaskParser<'a> {
DaskParserUtils::elements_from_tablefactor(&self.parser.parse_table_factor()?)?;
self.parser.expect_token(&Token::Comma)?;

let sql_statement = self.parser.parse_statement()?;
// Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements
// TODO: find a more sophisticated way to allow any statement that would return a table
self.parser.expect_one_of_keywords(&[
Keyword::SELECT,
Keyword::DESCRIBE,
Keyword::SHOW,
Keyword::ANALYZE,
])?;
self.parser.prev_token();

let sql_statement = self.parse_statement()?;
self.parser.expect_token(&Token::RParen)?;

let predict = PredictModel {
Expand All @@ -675,12 +685,27 @@ impl<'a> DaskParser<'a> {
let table_factor = self.parser.parse_table_factor()?;
let with_options = DaskParserUtils::options_from_tablefactor(&table_factor);

// Parse the "AS" before the SQLStatement
// Parse the nested query statement
self.parser.expect_keyword(Keyword::AS)?;
self.parser.expect_token(&Token::LParen)?;

// Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements
// TODO: find a more sophisticated way to allow any statement that would return a table
self.parser.expect_one_of_keywords(&[
Keyword::SELECT,
Keyword::DESCRIBE,
Keyword::SHOW,
Keyword::ANALYZE,
])?;
self.parser.prev_token();

let select = self.parse_statement()?;

self.parser.expect_token(&Token::RParen)?;

let create = CreateModel {
name: model_name.to_string(),
select: self.parser.parse_statement()?,
select,
if_not_exists,
or_replace,
with_options,
Expand Down
8 changes: 2 additions & 6 deletions dask_planner/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,7 @@ impl DaskSQLContext {
DaskStatement::CreateModel(create_model) => Ok(LogicalPlan::Extension(Extension {
node: Arc::new(CreateModelPlanNode {
model_name: create_model.name,
input: self._logical_relational_algebra(DaskStatement::Statement(Box::new(
create_model.select,
)))?,
input: self._logical_relational_algebra(create_model.select)?,
if_not_exists: create_model.if_not_exists,
or_replace: create_model.or_replace,
with_options: create_model.with_options,
Expand All @@ -393,9 +391,7 @@ impl DaskSQLContext {
node: Arc::new(PredictModelPlanNode {
model_schema: predict_model.schema_name,
model_name: predict_model.name,
input: self._logical_relational_algebra(DaskStatement::Statement(Box::new(
predict_model.select,
)))?,
input: self._logical_relational_algebra(predict_model.select)?,
}),
})),
DaskStatement::DescribeModel(describe_model) => Ok(LogicalPlan::Extension(Extension {
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,41 @@ def test_clustering_and_prediction(c, training_df):
check_trained_model(c)


# TODO - many ML tests fail on clusters without sklearn - can we avoid this?
@skip_if_external_scheduler
def test_create_model_with_prediction(c, training_df):
    """Train a model, then train a second model whose input is a nested
    PREDICT over the first — exercising parsing of a statement nested
    inside CREATE MODEL ... AS ( ... )."""
    # First model: trained directly on a SELECT over the timeseries table.
    base_model_query = """
        CREATE MODEL my_model1 WITH (
            model_class = 'sklearn.ensemble.GradientBoostingClassifier',
            wrap_predict = True,
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """
    # Second model: its training input is itself a PREDICT statement,
    # which must be accepted where a plain SELECT would normally appear.
    nested_model_query = """
        CREATE MODEL my_model2 WITH (
            model_class = 'sklearn.ensemble.GradientBoostingClassifier',
            wrap_predict = True,
            target_column = 'target'
        ) AS (
            SELECT * FROM PREDICT (
                MODEL my_model1,
                SELECT x, y FROM timeseries LIMIT 100
            )
        )
    """

    c.sql(base_model_query)
    c.sql(nested_model_query)

    # Verify the second (nested-input) model was actually trained.
    check_trained_model(c, "my_model2")


# TODO - many ML tests fail on clusters without sklearn - can we avoid this?
@pytest.mark.skip(
    reason="WIP DataFusion - fails to parse ARRAY in KV pairs in WITH clause, WITH clause was previously ignored"
Expand Down