Support table migration to Unity Catalog in Python code (#1210)
## Changes
Enhance the SparkSql linter/fixer to support migration of PySpark API calls that receive a table name as a parameter (for example `spark.table(...)` or `saveAsTable(...)`), in addition to SQL strings passed to `spark.sql(...)`. A hypothetical before/after is sketched below.
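
For illustration only — the table names and their Unity Catalog destinations below are made up; the real mapping comes from the migration index that UCX builds:

```python
# Hypothetical notebook code before `databricks labs ucx migrate_local_code`:
df = spark.table("old_db.things")
df.write.saveAsTable("old_db.things_copy")

# After the fixer runs, assuming both tables were already migrated to Unity Catalog:
df = spark.table("catalog.new_db.things")
df.write.saveAsTable("catalog.new_db.things_copy")
```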

### Linked issues
Resolves #1082

### Functionality 

- [ ] added relevant user documentation
- [ ] added new CLI command
- [x] modified existing command: `databricks labs ucx
migrate_local_code`
- [ ] added a new workflow
- [ ] modified existing workflow: `...`
- [ ] added a new table
- [ ] modified existing table: `...`

### Tests
- [ ] manually tested
- [x] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)

Integration tests are still missing.

---------

Co-authored-by: Serge Smertin <[email protected]>
ericvergnaud and nfx authored Apr 2, 2024
1 parent b20abdc commit de05703
Showing 3 changed files with 427 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/source_code/languages.py
@@ -11,11 +11,11 @@ def __init__(self, index: MigrationIndex):
         self._index = index
         from_table = FromTable(index)
         self._linters = {
-            Language.PYTHON: SequentialLinter([SparkSql(from_table)]),
+            Language.PYTHON: SequentialLinter([SparkSql(from_table, index)]),
             Language.SQL: SequentialLinter([from_table]),
         }
         self._fixers: dict[Language, list[Fixer]] = {
-            Language.PYTHON: [SparkSql(from_table)],
+            Language.PYTHON: [SparkSql(from_table, index)],
             Language.SQL: [from_table],
         }

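A minimal usage sketch of the rewired Python linter, assuming `index` is an already-populated `MigrationIndex` (constructing one is out of scope here):

```python
from databricks.labs.ucx.source_code.pyspark import SparkSql
from databricks.labs.ucx.source_code.queries import FromTable

linter = SparkSql(FromTable(index), index)  # same wiring as in languages.py above
for advice in linter.lint('df = spark.table("old_db.things")'):
    print(advice)  # Advisory/Deprecation advices with code 'table-migrate'
```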
264 changes: 228 additions & 36 deletions src/databricks/labs/ucx/source_code/pyspark.py
@@ -1,13 +1,220 @@
 import ast
-from collections.abc import Iterable
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator
+from dataclasses import dataclass

-from databricks.labs.ucx.source_code.base import Advice, Fixer, Linter
+from databricks.labs.ucx.hive_metastore.table_migrate import MigrationIndex
+from databricks.labs.ucx.source_code.base import (
+    Advice,
+    Advisory,
+    Deprecation,
+    Fixer,
+    Linter,
+)
 from databricks.labs.ucx.source_code.queries import FromTable


+@dataclass
+class Matcher(ABC):
+    method_name: str
+    min_args: int
+    max_args: int
+    table_arg_index: int
+    table_arg_name: str | None = None
+
+    def matches(self, node: ast.AST):
+        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
+            return False
+        return self._get_table_arg(node) is not None
+
+    @abstractmethod
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        raise NotImplementedError()
+
+    def _get_table_arg(self, node: ast.Call):
+        if len(node.args) > 0:
+            return node.args[self.table_arg_index] if self.min_args <= len(node.args) <= self.max_args else None
+        assert self.table_arg_name is not None
+        arg = next(kw for kw in node.keywords if kw.arg == self.table_arg_name)
+        return arg.value if arg is not None else None
+
+
+@dataclass
+class QueryMatcher(Matcher):
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        table_arg = self._get_table_arg(node)
+        if isinstance(table_arg, ast.Constant):
+            for advice in from_table.lint(table_arg.value):
+                yield advice.replace(
+                    start_line=node.lineno,
+                    start_col=node.col_offset,
+                    end_line=node.end_lineno,
+                    end_col=node.end_col_offset,
+                )
+        else:
+            yield Advisory(
+                code='table-migrate',
+                message=f"Can't migrate '{node}' because its table name argument is not a constant",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        table_arg = self._get_table_arg(node)
+        assert isinstance(table_arg, ast.Constant)
+        new_query = from_table.apply(table_arg.value)
+        table_arg.value = new_query
+
+
+@dataclass
+class TableNameMatcher(Matcher):
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        table_arg = self._get_table_arg(node)
+        if isinstance(table_arg, ast.Constant):
+            dst = self._find_dest(index, table_arg.value)
+            if dst is not None:
+                yield Deprecation(
+                    code='table-migrate',
+                    message=f"Table {table_arg.value} is migrated to {dst.destination()} in Unity Catalog",
+                    # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
+                    start_line=node.lineno,
+                    start_col=node.col_offset,
+                    end_line=node.end_lineno or 0,
+                    end_col=node.end_col_offset or 0,
+                )
+        else:
+            assert isinstance(node.func, ast.Attribute)  # always true, avoids a pylint warning
+            yield Advisory(
+                code='table-migrate',
+                message=f"Can't migrate '{node.func.attr}' because its table name argument is not a constant",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        table_arg = self._get_table_arg(node)
+        assert isinstance(table_arg, ast.Constant)
+        dst = self._find_dest(index, table_arg.value)
+        if dst is not None:
+            table_arg.value = dst.destination()
+
+    @staticmethod
+    def _find_dest(index: MigrationIndex, value: str):
+        parts = value.split(".")
+        return None if len(parts) != 2 else index.get(parts[0], parts[1])
+
+
+@dataclass
+class ReturnValueMatcher(Matcher):
+
+    def matches(self, node: ast.AST):
+        return isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        assert isinstance(node.func, ast.Attribute)  # always true, avoids a pylint warning
+        yield Advisory(
+            code='table-migrate',
+            message=f"Call to '{node.func.attr}' will return a list of <catalog>.<database>.<table> instead of <database>.<table>.",
+            start_line=node.lineno,
+            start_col=node.col_offset,
+            end_line=node.end_lineno or 0,
+            end_col=node.end_col_offset or 0,
+        )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        raise NotImplementedError("Should never get there!")
+
+
+class SparkMatchers:
+
+    def __init__(self):
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.html
+        spark_session_matchers = [QueryMatcher("sql", 1, 1000, 0, "sqlQuery"), TableNameMatcher("table", 1, 1, 0)]
+
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Catalog.html
+        spark_catalog_matchers = [
+            TableNameMatcher("cacheTable", 1, 2, 0, "tableName"),
+            TableNameMatcher("createTable", 1, 1000, 0, "tableName"),
+            TableNameMatcher("createExternalTable", 1, 1000, 0, "tableName"),
+            TableNameMatcher("getTable", 1, 1, 0),
+            TableNameMatcher("isCached", 1, 1, 0),
+            TableNameMatcher("listColumns", 1, 2, 0, "tableName"),
+            TableNameMatcher("tableExists", 1, 2, 0, "tableName"),
+            TableNameMatcher("recoverPartitions", 1, 1, 0),
+            TableNameMatcher("refreshTable", 1, 1, 0),
+            TableNameMatcher("uncacheTable", 1, 1, 0),
+            ReturnValueMatcher("listTables", 0, 2, -1),
+        ]
+
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html
+        spark_dataframe_matchers = [
+            TableNameMatcher("writeTo", 1, 1, 0),
+        ]
+
+        # nothing to migrate in Column, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.html
+        # nothing to migrate in Observation, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Observation.html
+        # nothing to migrate in Row, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Row.html
+        # nothing to migrate in GroupedData, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.GroupedData.html
+        # nothing to migrate in PandasCogroupedOps, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.PandasCogroupedOps.html
+        # nothing to migrate in DataFrameNaFunctions, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameNaFunctions.html
+        # nothing to migrate in DataFrameStatFunctions, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameStatFunctions.html
+        # nothing to migrate in Window, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Window.html
+
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.html
+        spark_dataframereader_matchers = [
+            TableNameMatcher("table", 1, 1, 0),  # TODO good example of collision, see spark_session_calls
+        ]
+
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.html
+        spark_dataframewriter_matchers = [
+            TableNameMatcher("insertInto", 1, 2, 0, "tableName"),
+            # TODO jdbc: could the url be a databricks url, raise warning ?
+            TableNameMatcher("saveAsTable", 1, 4, 0, "name"),
+        ]
+
+        # nothing to migrate in DataFrameWriterV2, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriterV2.html
+        # nothing to migrate in UDFRegistration, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UDFRegistration.html
+
+        # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UDTFRegistration.html
+        spark_udtfregistration_matchers = [
+            TableNameMatcher("register", 1, 2, 0, "name"),
+        ]
+
+        # nothing to migrate in UserDefinedFunction, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UserDefinedFunction.html
+        # nothing to migrate in UserDefinedTableFunction, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UserDefinedTableFunction.html
+        self._matchers = {}
+        for matcher in (
+            spark_session_matchers
+            + spark_catalog_matchers
+            + spark_dataframe_matchers
+            + spark_dataframereader_matchers
+            + spark_dataframewriter_matchers
+            + spark_udtfregistration_matchers
+        ):
+            self._matchers[matcher.method_name] = matcher
+
+    @property
+    def matchers(self):
+        return self._matchers
+
+
 class SparkSql(Linter, Fixer):
-    def __init__(self, from_table: FromTable):
+
+    _spark_matchers = SparkMatchers()
+
+    def __init__(self, from_table: FromTable, index: MigrationIndex):
         self._from_table = from_table
+        self._index = index

     def name(self) -> str:
         # this is the same fixer, just in a different language context
@@ -16,44 +16,29 @@ def name(self) -> str:
     def lint(self, code: str) -> Iterable[Advice]:
         tree = ast.parse(code)
         for node in ast.walk(tree):
-            if not isinstance(node, ast.Call):
-                continue
-            if not isinstance(node.func, ast.Attribute):
+            matcher = self._find_matcher(node)
+            if matcher is None:
                 continue
-            if node.func.attr != "sql":
-                continue
-            if len(node.args) != 1:
-                continue
-            first_arg = node.args[0]
-            if not isinstance(first_arg, ast.Constant):
-                # `astroid` library supports inference and parent node lookup,
-                # which makes traversing the AST a bit easier.
-                continue
-            query = first_arg.value
-            for advice in self._from_table.lint(query):
-                yield advice.replace(
-                    start_line=node.lineno,
-                    start_col=node.col_offset,
-                    end_line=node.end_lineno,
-                    end_col=node.end_col_offset,
-                )
+            assert isinstance(node, ast.Call)
+            yield from matcher.lint(self._from_table, self._index, node)

     def apply(self, code: str) -> str:
         tree = ast.parse(code)
-        # we won't be doing it like this in production, but for the sake of the example
         for node in ast.walk(tree):
-            if not isinstance(node, ast.Call):
-                continue
-            if not isinstance(node.func, ast.Attribute):
-                continue
-            if node.func.attr != "sql":
+            matcher = self._find_matcher(node)
+            if matcher is None:
                 continue
-            if len(node.args) != 1:
-                continue
-            first_arg = node.args[0]
-            if not isinstance(first_arg, ast.Constant):
-                continue
-            query = first_arg.value
-            new_query = self._from_table.apply(query)
-            first_arg.value = new_query
+            assert isinstance(node, ast.Call)
+            matcher.apply(self._from_table, self._index, node)
         return ast.unparse(tree)
+
+    def _find_matcher(self, node: ast.AST):
+        if not isinstance(node, ast.Call):
+            return None
+        if not isinstance(node.func, ast.Attribute):
+            return None
+        matcher = self._spark_matchers.matchers.get(node.func.attr, None)
+        if matcher is None:
+            return None
+        return matcher if matcher.matches(node) else None
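
To make the matcher mechanics above concrete, here is a small standalone illustration (plain `ast`, no UCX imports) of how a matcher configured with `table_arg_index=0` and `table_arg_name="name"` locates the table-name argument of `saveAsTable`, whether it is passed positionally or by keyword. It is a simplified re-expression, not the committed code: the `next(..., None)` default keeps the keyword lookup from raising when the argument is absent, and the min/max argument-count bounds are omitted.

```python
import ast

def find_table_arg(call: ast.Call, table_arg_index: int, table_arg_name: str | None):
    # positional form, e.g. df.write.saveAsTable("old_db.things")
    if call.args:
        return call.args[table_arg_index]
    # keyword form, e.g. df.write.saveAsTable(name="old_db.things")
    if table_arg_name is None:
        return None
    keyword = next((kw for kw in call.keywords if kw.arg == table_arg_name), None)
    return keyword.value if keyword is not None else None

call = ast.parse('df.write.saveAsTable(name="old_db.things")').body[0].value
arg = find_table_arg(call, 0, "name")
assert isinstance(arg, ast.Constant)
print(arg.value)  # old_db.things
```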
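And a quick way to inspect which PySpark method names the new registry covers, assuming the ucx package with this change is installed:

```python
from databricks.labs.ucx.source_code.pyspark import SparkMatchers

matchers = SparkMatchers().matchers
print(sorted(matchers))                        # e.g. ['cacheTable', 'createTable', ..., 'writeTo']
print(type(matchers["saveAsTable"]).__name__)  # TableNameMatcher
```

Note that, as the TODO in the diff points out, `table` appears in both the SparkSession and DataFrameReader matcher lists; because the registry is keyed by method name, the later entry wins.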
