diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index 21c7f012c0..ce637958d8 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -2,10 +2,10 @@ from collections.abc import Iterable from astroid import Call, Const, InferenceError, NodeNG # type: ignore -import sqlglot +from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError from sqlglot.expressions import Table -from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState +from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue logger = logging.getLogger(__name__) @@ -95,12 +95,28 @@ def name() -> str: return 'dbfs-query' def lint(self, code: str) -> Iterable[Advice]: - for statement in sqlglot.parse(code, read='databricks'): - if not statement: - continue - for table in statement.find_all(Table): - # Check table names for deprecated DBFS table names - yield from self._check_dbfs_folder(table) + try: + queries = parse_sql(code, read='databricks') + for query in queries: + if not query: + continue + yield from self._lint_query(query) + except SqlParseError as e: + logger.debug(f"Failed to parse SQL: {code}", exc_info=e) + yield Failure( + code='dbfs-query', + message=f"SQL query is not supported yet: {code}", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ) + + def _lint_query(self, query: Expression): + for table in query.find_all(Table): + # Check table names for deprecated DBFS table names + yield from self._check_dbfs_folder(table) def _check_dbfs_folder(self, table: Table) -> Iterable[Advice]: """ diff --git a/src/databricks/labs/ucx/source_code/queries.py b/src/databricks/labs/ucx/source_code/queries.py index d465368889..b51c9e4853 100644 --- a/src/databricks/labs/ucx/source_code/queries.py +++ b/src/databricks/labs/ucx/source_code/queries.py @@ -1,10 +1,10 @@ from collections.abc import Iterable import logging -import sqlglot +from sqlglot import parse as parse_sql, ParseError as SqlParseError from sqlglot.expressions import Table, Expression, Use, Create from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState +from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState, Failure logger = logging.getLogger(__name__) @@ -43,41 +43,57 @@ def schema(self): return self._session_state.schema def lint(self, code: str) -> Iterable[Advice]: - for statement in sqlglot.parse(code, read='databricks'): - if not statement: - continue - for table in statement.find_all(Table): - if isinstance(statement, Use): - # Sqlglot captures the database name in the Use statement as a Table, with - # the schema as the table name. - self._session_state.schema = table.name - continue - if isinstance(statement, Create) and statement.kind == "SCHEMA": - # Sqlglot captures the schema name in the Create statement as a Table, with - # the schema as the db name. - self._session_state.schema = table.db + try: + statements = parse_sql(code, read='databricks') + for statement in statements: + if not statement: continue + yield from self._lint_statement(statement) + except SqlParseError as e: + logger.debug(f"Failed to parse SQL: {code}", exc_info=e) + yield Failure( + code='table-migrate', + message=f"SQL query is not supported yet: {code}", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ) + + def _lint_statement(self, statement: Expression): + for table in statement.find_all(Table): + if isinstance(statement, Use): + # Sqlglot captures the database name in the Use statement as a Table, with + # the schema as the table name. + self._session_state.schema = table.name + continue + if isinstance(statement, Create) and statement.kind == "SCHEMA": + # Sqlglot captures the schema name in the Create statement as a Table, with + # the schema as the db name. + self._session_state.schema = table.db + continue - # we only migrate tables in the hive_metastore catalog - if self._catalog(table) != 'hive_metastore': - continue - # Sqlglot uses db instead of schema, watch out for that - src_schema = table.db if table.db else self._session_state.schema - if not src_schema: - logger.error(f"Could not determine schema for table {table.name}") - continue - dst = self._index.get(src_schema, table.name) - if not dst: - continue - yield Deprecation( - code='table-migrate', - message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog", - # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 - start_line=0, - start_col=0, - end_line=0, - end_col=1024, - ) + # we only migrate tables in the hive_metastore catalog + if self._catalog(table) != 'hive_metastore': + continue + # Sqlglot uses db instead of schema, watch out for that + src_schema = table.db if table.db else self._session_state.schema + if not src_schema: + logger.error(f"Could not determine schema for table {table.name}") + continue + dst = self._index.get(src_schema, table.name) + if not dst: + continue + yield Deprecation( + code='table-migrate', + message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ) @staticmethod def _catalog(table): @@ -87,7 +103,7 @@ def _catalog(table): def apply(self, code: str) -> str: new_statements = [] - for statement in sqlglot.parse(code, read='databricks'): + for statement in parse_sql(code, read='databricks'): if not statement: continue if isinstance(statement, Use): diff --git a/tests/unit/source_code/linters/test_dbfs.py b/tests/unit/source_code/linters/test_dbfs.py index 9ce795eee7..54721dca7f 100644 --- a/tests/unit/source_code/linters/test_dbfs.py +++ b/tests/unit/source_code/linters/test_dbfs.py @@ -1,6 +1,6 @@ import pytest -from databricks.labs.ucx.source_code.base import Deprecation, Advisory, CurrentSessionState +from databricks.labs.ucx.source_code.base import Deprecation, Advisory, CurrentSessionState, Failure from databricks.labs.ucx.source_code.linters.dbfs import DBFSUsageLinter, FromDbfsFolder @@ -91,7 +91,8 @@ def test_non_dbfs_trigger_nothing(query): ) def test_dbfs_tables_trigger_messages_param(query: str, table: str): ftf = FromDbfsFolder() - assert [ + actual = list(ftf.lint(query)) + assert actual == [ Deprecation( code='dbfs-query', message=f'The use of DBFS is deprecated: {table}', @@ -100,7 +101,28 @@ def test_dbfs_tables_trigger_messages_param(query: str, table: str): end_line=0, end_col=1024, ), - ] == list(ftf.lint(query)) + ] + + +@pytest.mark.parametrize( + "query", + [ + 'SELECT * FROM {{some_db.some_table}}', + ], +) +def test_dbfs_queries_failure(query: str): + ftf = FromDbfsFolder() + actual = list(ftf.lint(query)) + assert actual == [ + Failure( + code='dbfs-query', + message=f'SQL query is not supported yet: {query}', + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ), + ] def test_dbfs_queries_name(): diff --git a/tests/unit/source_code/test_queries.py b/tests/unit/source_code/test_queries.py index 0e5fc628d5..62473e6591 100644 --- a/tests/unit/source_code/test_queries.py +++ b/tests/unit/source_code/test_queries.py @@ -1,4 +1,4 @@ -from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState +from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, Failure from databricks.labs.ucx.source_code.queries import FromTable @@ -6,8 +6,8 @@ def test_not_migrated_tables_trigger_nothing(empty_index): ftf = FromTable(empty_index, CurrentSessionState()) old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10" - - assert not list(ftf.lint(old_query)) + actual = list(ftf.lint(old_query)) + assert not actual def test_migrated_tables_trigger_messages(migration_index): @@ -85,3 +85,12 @@ def test_parses_create_schema(migration_index): ftf = FromTable(migration_index, session_state=session_state) advices = ftf.lint(query) assert not list(advices) + + +def test_raises_advice_when_parsing_unsupported_sql(migration_index): + query = "DESCRIBE DETAIL xyz" + session_state = CurrentSessionState(schema="old") + ftf = FromTable(migration_index, session_state=session_state) + advices = list(ftf.lint(query)) + assert isinstance(advices[0], Failure) + assert 'not supported' in advices[0].message