Skip to content

Commit

Permalink
catch sqlglot exceptions and convert them to advices (#1915)
Browse files Browse the repository at this point in the history
## Changes
Generate a `Failure` advice when sqlglot fails to parse a query.

### Linked issues
Progresses #1901

### Functionality 
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <[email protected]>
Co-authored-by: Andrew Snare <[email protected]>
  • Loading branch information
3 people authored Jun 18, 2024
1 parent deae2df commit 9fc168b
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 50 deletions.
32 changes: 24 additions & 8 deletions src/databricks/labs/ucx/source_code/linters/dbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from collections.abc import Iterable

from astroid import Call, Const, InferenceError, NodeNG # type: ignore
import sqlglot
from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError
from sqlglot.expressions import Table

from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState
from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -95,12 +95,28 @@ def name() -> str:
return 'dbfs-query'

def lint(self, code: str) -> Iterable[Advice]:
for statement in sqlglot.parse(code, read='databricks'):
if not statement:
continue
for table in statement.find_all(Table):
# Check table names for deprecated DBFS table names
yield from self._check_dbfs_folder(table)
try:
queries = parse_sql(code, read='databricks')
for query in queries:
if not query:
continue
yield from self._lint_query(query)
except SqlParseError as e:
logger.debug(f"Failed to parse SQL: {code}", exc_info=e)
yield Failure(
code='dbfs-query',
message=f"SQL query is not supported yet: {code}",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)

def _lint_query(self, query: Expression):
for table in query.find_all(Table):
# Check table names for deprecated DBFS table names
yield from self._check_dbfs_folder(table)

def _check_dbfs_folder(self, table: Table) -> Iterable[Advice]:
"""
Expand Down
88 changes: 52 additions & 36 deletions src/databricks/labs/ucx/source_code/queries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections.abc import Iterable

import logging
import sqlglot
from sqlglot import parse as parse_sql, ParseError as SqlParseError
from sqlglot.expressions import Table, Expression, Use, Create
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState, Failure

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,41 +43,57 @@ def schema(self):
return self._session_state.schema

def lint(self, code: str) -> Iterable[Advice]:
for statement in sqlglot.parse(code, read='databricks'):
if not statement:
continue
for table in statement.find_all(Table):
if isinstance(statement, Use):
# Sqlglot captures the database name in the Use statement as a Table, with
# the schema as the table name.
self._session_state.schema = table.name
continue
if isinstance(statement, Create) and statement.kind == "SCHEMA":
# Sqlglot captures the schema name in the Create statement as a Table, with
# the schema as the db name.
self._session_state.schema = table.db
try:
statements = parse_sql(code, read='databricks')
for statement in statements:
if not statement:
continue
yield from self._lint_statement(statement)
except SqlParseError as e:
logger.debug(f"Failed to parse SQL: {code}", exc_info=e)
yield Failure(
code='table-migrate',
message=f"SQL query is not supported yet: {code}",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)

def _lint_statement(self, statement: Expression):
for table in statement.find_all(Table):
if isinstance(statement, Use):
# Sqlglot captures the database name in the Use statement as a Table, with
# the schema as the table name.
self._session_state.schema = table.name
continue
if isinstance(statement, Create) and statement.kind == "SCHEMA":
# Sqlglot captures the schema name in the Create statement as a Table, with
# the schema as the db name.
self._session_state.schema = table.db
continue

# we only migrate tables in the hive_metastore catalog
if self._catalog(table) != 'hive_metastore':
continue
# Sqlglot uses db instead of schema, watch out for that
src_schema = table.db if table.db else self._session_state.schema
if not src_schema:
logger.error(f"Could not determine schema for table {table.name}")
continue
dst = self._index.get(src_schema, table.name)
if not dst:
continue
yield Deprecation(
code='table-migrate',
message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)
# we only migrate tables in the hive_metastore catalog
if self._catalog(table) != 'hive_metastore':
continue
# Sqlglot uses db instead of schema, watch out for that
src_schema = table.db if table.db else self._session_state.schema
if not src_schema:
logger.error(f"Could not determine schema for table {table.name}")
continue
dst = self._index.get(src_schema, table.name)
if not dst:
continue
yield Deprecation(
code='table-migrate',
message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)

@staticmethod
def _catalog(table):
Expand All @@ -87,7 +103,7 @@ def _catalog(table):

def apply(self, code: str) -> str:
new_statements = []
for statement in sqlglot.parse(code, read='databricks'):
for statement in parse_sql(code, read='databricks'):
if not statement:
continue
if isinstance(statement, Use):
Expand Down
28 changes: 25 additions & 3 deletions tests/unit/source_code/linters/test_dbfs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from databricks.labs.ucx.source_code.base import Deprecation, Advisory, CurrentSessionState
from databricks.labs.ucx.source_code.base import Deprecation, Advisory, CurrentSessionState, Failure
from databricks.labs.ucx.source_code.linters.dbfs import DBFSUsageLinter, FromDbfsFolder


Expand Down Expand Up @@ -91,7 +91,8 @@ def test_non_dbfs_trigger_nothing(query):
)
def test_dbfs_tables_trigger_messages_param(query: str, table: str):
ftf = FromDbfsFolder()
assert [
actual = list(ftf.lint(query))
assert actual == [
Deprecation(
code='dbfs-query',
message=f'The use of DBFS is deprecated: {table}',
Expand All @@ -100,7 +101,28 @@ def test_dbfs_tables_trigger_messages_param(query: str, table: str):
end_line=0,
end_col=1024,
),
] == list(ftf.lint(query))
]


@pytest.mark.parametrize(
"query",
[
'SELECT * FROM {{some_db.some_table}}',
],
)
def test_dbfs_queries_failure(query: str):
ftf = FromDbfsFolder()
actual = list(ftf.lint(query))
assert actual == [
Failure(
code='dbfs-query',
message=f'SQL query is not supported yet: {query}',
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
),
]


def test_dbfs_queries_name():
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/source_code/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState
from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, Failure
from databricks.labs.ucx.source_code.queries import FromTable


def test_not_migrated_tables_trigger_nothing(empty_index):
ftf = FromTable(empty_index, CurrentSessionState())

old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10"

assert not list(ftf.lint(old_query))
actual = list(ftf.lint(old_query))
assert not actual


def test_migrated_tables_trigger_messages(migration_index):
Expand Down Expand Up @@ -85,3 +85,12 @@ def test_parses_create_schema(migration_index):
ftf = FromTable(migration_index, session_state=session_state)
advices = ftf.lint(query)
assert not list(advices)


def test_raises_advice_when_parsing_unsupported_sql(migration_index):
query = "DESCRIBE DETAIL xyz"
session_state = CurrentSessionState(schema="old")
ftf = FromTable(migration_index, session_state=session_state)
advices = list(ftf.lint(query))
assert isinstance(advices[0], Failure)
assert 'not supported' in advices[0].message

0 comments on commit 9fc168b

Please sign in to comment.