diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs index d42cb7f59..26ea8b49e 100644 --- a/dask_planner/src/parser.rs +++ b/dask_planner/src/parser.rs @@ -24,7 +24,7 @@ pub struct CreateModel { /// model name pub name: String, /// input Query - pub select: SQLStatement, + pub select: DaskStatement, /// IF NOT EXISTS pub if_not_exists: bool, /// To replace the model or not @@ -41,7 +41,7 @@ pub struct PredictModel { /// model name pub name: String, /// input Query - pub select: SQLStatement, + pub select: DaskStatement, } /// Dask-SQL extension DDL for `CREATE SCHEMA` @@ -648,7 +648,17 @@ impl<'a> DaskParser<'a> { DaskParserUtils::elements_from_tablefactor(&self.parser.parse_table_factor()?)?; self.parser.expect_token(&Token::Comma)?; - let sql_statement = self.parser.parse_statement()?; + // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements + // TODO: find a more sophisticated way to allow any statement that would return a table + self.parser.expect_one_of_keywords(&[ + Keyword::SELECT, + Keyword::DESCRIBE, + Keyword::SHOW, + Keyword::ANALYZE, + ])?; + self.parser.prev_token(); + + let sql_statement = self.parse_statement()?; self.parser.expect_token(&Token::RParen)?; let predict = PredictModel { @@ -675,12 +685,27 @@ impl<'a> DaskParser<'a> { let table_factor = self.parser.parse_table_factor()?; let with_options = DaskParserUtils::options_from_tablefactor(&table_factor); - // Parse the "AS" before the SQLStatement + // Parse the nested query statement self.parser.expect_keyword(Keyword::AS)?; + self.parser.expect_token(&Token::LParen)?; + + // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements + // TODO: find a more sophisticated way to allow any statement that would return a table + self.parser.expect_one_of_keywords(&[ + Keyword::SELECT, + Keyword::DESCRIBE, + Keyword::SHOW, + Keyword::ANALYZE, + ])?; + self.parser.prev_token(); + + let select = self.parse_statement()?; + + self.parser.expect_token(&Token::RParen)?; let create = CreateModel { name: model_name.to_string(), - select: self.parser.parse_statement()?, + select, if_not_exists, or_replace, with_options, diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 1783f4d70..948f34839 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -381,9 +381,7 @@ impl DaskSQLContext { DaskStatement::CreateModel(create_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(CreateModelPlanNode { model_name: create_model.name, - input: self._logical_relational_algebra(DaskStatement::Statement(Box::new( - create_model.select, - )))?, + input: self._logical_relational_algebra(create_model.select)?, if_not_exists: create_model.if_not_exists, or_replace: create_model.or_replace, with_options: create_model.with_options, @@ -393,9 +391,7 @@ impl DaskSQLContext { node: Arc::new(PredictModelPlanNode { model_schema: predict_model.schema_name, model_name: predict_model.name, - input: self._logical_relational_algebra(DaskStatement::Statement(Box::new( - predict_model.select, - )))?, + input: self._logical_relational_algebra(predict_model.select)?, }), })), DaskStatement::DescribeModel(describe_model) => Ok(LogicalPlan::Extension(Extension { diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index ef6c41e14..8dc1aa8f8 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -170,6 +170,41 @@ def test_clustering_and_prediction(c, training_df): check_trained_model(c) +# TODO - many ML tests fail on clusters without sklearn - can we avoid this? +@skip_if_external_scheduler +def test_create_model_with_prediction(c, training_df): + c.sql( + """ + CREATE MODEL my_model1 WITH ( + model_class = 'sklearn.ensemble.GradientBoostingClassifier', + wrap_predict = True, + target_column = 'target' + ) AS ( + SELECT x, y, x*y > 0 AS target + FROM timeseries + LIMIT 100 + ) + """ + ) + + c.sql( + """ + CREATE MODEL my_model2 WITH ( + model_class = 'sklearn.ensemble.GradientBoostingClassifier', + wrap_predict = True, + target_column = 'target' + ) AS ( + SELECT * FROM PREDICT ( + MODEL my_model1, + SELECT x, y FROM timeseries LIMIT 100 + ) + ) + """ + ) + + check_trained_model(c, "my_model2") + + # TODO - many ML tests fail on clusters without sklearn - can we avoid this? @pytest.mark.skip( reason="WIP DataFusion - fails to parse ARRAY in KV pairs in WITH clause, WITH clause was previsouly ignored"