diff --git a/src/databricks/labs/ucx/source_code/languages.py b/src/databricks/labs/ucx/source_code/languages.py
index 2b022623f7..a78326cfd5 100644
--- a/src/databricks/labs/ucx/source_code/languages.py
+++ b/src/databricks/labs/ucx/source_code/languages.py
@@ -11,11 +11,11 @@ def __init__(self, index: MigrationIndex):
         self._index = index
         from_table = FromTable(index)
         self._linters = {
-            Language.PYTHON: SequentialLinter([SparkSql(from_table)]),
+            Language.PYTHON: SequentialLinter([SparkSql(from_table, index)]),
             Language.SQL: SequentialLinter([from_table]),
         }
         self._fixers: dict[Language, list[Fixer]] = {
-            Language.PYTHON: [SparkSql(from_table)],
+            Language.PYTHON: [SparkSql(from_table, index)],
             Language.SQL: [from_table],
         }
diff --git a/src/databricks/labs/ucx/source_code/pyspark.py b/src/databricks/labs/ucx/source_code/pyspark.py
index 3660fd725a..8433c0ad4b 100644
--- a/src/databricks/labs/ucx/source_code/pyspark.py
+++ b/src/databricks/labs/ucx/source_code/pyspark.py
@@ -1,13 +1,220 @@
 import ast
-from collections.abc import Iterable
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator
+from dataclasses import dataclass
 
-from databricks.labs.ucx.source_code.base import Advice, Fixer, Linter
+from databricks.labs.ucx.hive_metastore.table_migrate import MigrationIndex
+from databricks.labs.ucx.source_code.base import (
+    Advice,
+    Advisory,
+    Deprecation,
+    Fixer,
+    Linter,
+)
 from databricks.labs.ucx.source_code.queries import FromTable
 
 
+@dataclass
+class Matcher(ABC):
+    method_name: str
+    min_args: int
+    max_args: int
+    table_arg_index: int
+    table_arg_name: str | None = None
+
+    def matches(self, node: ast.AST):
+        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
+            return False
+        return self._get_table_arg(node) is not None
+
+    @abstractmethod
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        raise NotImplementedError()
+
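+    # the table argument is taken positionally when the call has positional
+    # arguments (and a plausible count), and is looked up by keyword otherwise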
+    def _get_table_arg(self, node: ast.Call):
+        if len(node.args) > 0:
+            return node.args[self.table_arg_index] if self.min_args <= len(node.args) <= self.max_args else None
+        assert self.table_arg_name is not None
+        # default to None so a missing keyword argument does not raise StopIteration
+        arg = next((kw for kw in node.keywords if kw.arg == self.table_arg_name), None)
+        return arg.value if arg is not None else None
+
+
+@dataclass
+class QueryMatcher(Matcher):
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        table_arg = self._get_table_arg(node)
+        if isinstance(table_arg, ast.Constant):
+            for advice in from_table.lint(table_arg.value):
+                yield advice.replace(
+                    start_line=node.lineno,
+                    start_col=node.col_offset,
+                    end_line=node.end_lineno,
+                    end_col=node.end_col_offset,
+                )
+        else:
+            assert isinstance(node.func, ast.Attribute)  # always true, avoids a pylint warning
+            yield Advisory(
+                code='table-migrate',
+                message=f"Can't migrate '{node.func.attr}' because its table name argument is not a constant",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        table_arg = self._get_table_arg(node)
+        assert isinstance(table_arg, ast.Constant)
+        new_query = from_table.apply(table_arg.value)
+        table_arg.value = new_query
+
+
+@dataclass
+class TableNameMatcher(Matcher):
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        table_arg = self._get_table_arg(node)
+        if isinstance(table_arg, ast.Constant):
+            dst = self._find_dest(index, table_arg.value)
+            if dst is not None:
+                yield Deprecation(
+                    code='table-migrate',
+                    message=f"Table {table_arg.value} is migrated to {dst.destination()} in Unity Catalog",
+                    # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
+                    start_line=node.lineno,
+                    start_col=node.col_offset,
+                    end_line=node.end_lineno or 0,
+                    end_col=node.end_col_offset or 0,
+                )
+        else:
+            assert isinstance(node.func, ast.Attribute)  # always true, avoids a pylint warning
+            yield Advisory(
+                code='table-migrate',
+                message=f"Can't migrate '{node.func.attr}' because its table name argument is not a constant",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        table_arg = self._get_table_arg(node)
+        assert isinstance(table_arg, ast.Constant)
+        dst = self._find_dest(index, table_arg.value)
+        if dst is not None:
+            table_arg.value = dst.destination()
+
+    @staticmethod
+    def _find_dest(index: MigrationIndex, value: str):
+        parts = value.split(".")
+        return None if len(parts) != 2 else index.get(parts[0], parts[1])
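+
+
+# flags calls whose return value changes shape after migration;
+# there is no safe rewrite, so apply() must never be called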
.", + start_line=node.lineno, + start_col=node.col_offset, + end_line=node.end_lineno or 0, + end_col=node.end_col_offset or 0, + ) + + def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + raise NotImplementedError("Should never get there!") + + +class SparkMatchers: + + def __init__(self): + # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.html + spark_session_matchers = [QueryMatcher("sql", 1, 1000, 0, "sqlQuery"), TableNameMatcher("table", 1, 1, 0)] + + # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Catalog.html + spark_catalog_matchers = [ + TableNameMatcher("cacheTable", 1, 2, 0, "tableName"), + TableNameMatcher("createTable", 1, 1000, 0, "tableName"), + TableNameMatcher("createExternalTable", 1, 1000, 0, "tableName"), + TableNameMatcher("getTable", 1, 1, 0), + TableNameMatcher("isCached", 1, 1, 0), + TableNameMatcher("listColumns", 1, 2, 0, "tableName"), + TableNameMatcher("tableExists", 1, 2, 0, "tableName"), + TableNameMatcher("recoverPartitions", 1, 1, 0), + TableNameMatcher("refreshTable", 1, 1, 0), + TableNameMatcher("uncacheTable", 1, 1, 0), + ReturnValueMatcher("listTables", 0, 2, -1), + ] + + # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html + spark_dataframe_matchers = [ + TableNameMatcher("writeTo", 1, 1, 0), + ] + + # nothing to migrate in Column, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.html + # nothing to migrate in Observation, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Observation.html + # nothing to migrate in Row, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Row.html + # nothing to migrate in GroupedData, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.GroupedData.html + # nothing to migrate in PandasCogroupedOps, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.PandasCogroupedOps.html + # nothing to migrate in DataFrameNaFunctions, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameNaFunctions.html + # nothing to migrate in DataFrameStatFunctions, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameStatFunctions.html + # nothing to migrate in Window, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Window.html + + # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.html + spark_dataframereader_matchers = [ + TableNameMatcher("table", 1, 1, 0), # TODO good example of collision, see spark_session_calls + ] + + # see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.html + spark_dataframewriter_matchers = [ + TableNameMatcher("insertInto", 1, 2, 0, "tableName"), + # TODO jdbc: could the url be a databricks url, raise warning ? 
+        self._matchers = {}
+        for matcher in (
+            spark_session_matchers
+            + spark_catalog_matchers
+            + spark_dataframe_matchers
+            + spark_dataframereader_matchers
+            + spark_dataframewriter_matchers
+            + spark_udtfregistration_matchers
+        ):
+            self._matchers[matcher.method_name] = matcher
+
+    @property
+    def matchers(self):
+        return self._matchers
+
+
 class SparkSql(Linter, Fixer):
-    def __init__(self, from_table: FromTable):
+
+    _spark_matchers = SparkMatchers()
+
+    def __init__(self, from_table: FromTable, index: MigrationIndex):
         self._from_table = from_table
+        self._index = index
 
     def name(self) -> str:
         # this is the same fixer, just in a different language context
@@ -16,44 +223,29 @@ def name(self) -> str:
     def lint(self, code: str) -> Iterable[Advice]:
         tree = ast.parse(code)
         for node in ast.walk(tree):
-            if not isinstance(node, ast.Call):
-                continue
-            if not isinstance(node.func, ast.Attribute):
+            matcher = self._find_matcher(node)
+            if matcher is None:
                 continue
-            if node.func.attr != "sql":
-                continue
-            if len(node.args) != 1:
-                continue
-            first_arg = node.args[0]
-            if not isinstance(first_arg, ast.Constant):
-                # `astroid` library supports inference and parent node lookup,
-                # which makes traversing the AST a bit easier.
-                continue
-            query = first_arg.value
-            for advice in self._from_table.lint(query):
-                yield advice.replace(
-                    start_line=node.lineno,
-                    start_col=node.col_offset,
-                    end_line=node.end_lineno,
-                    end_col=node.end_col_offset,
-                )
+            assert isinstance(node, ast.Call)
+            yield from matcher.lint(self._from_table, self._index, node)
 
     def apply(self, code: str) -> str:
         tree = ast.parse(code)
         # we won't be doing it like this in production, but for the sake of the example
         for node in ast.walk(tree):
-            if not isinstance(node, ast.Call):
-                continue
-            if not isinstance(node.func, ast.Attribute):
-                continue
-            if node.func.attr != "sql":
+            matcher = self._find_matcher(node)
+            if matcher is None:
                 continue
-            if len(node.args) != 1:
-                continue
-            first_arg = node.args[0]
-            if not isinstance(first_arg, ast.Constant):
-                continue
-            query = first_arg.value
-            new_query = self._from_table.apply(query)
-            first_arg.value = new_query
+            assert isinstance(node, ast.Call)
+            matcher.apply(self._from_table, self._index, node)
         return ast.unparse(tree)
+
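+    # dispatch in two stages: look the attribute name up in the matcher table,
+    # then let the matcher validate the call's argument shape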
+    def _find_matcher(self, node: ast.AST):
+        if not isinstance(node, ast.Call):
+            return None
+        if not isinstance(node.func, ast.Attribute):
+            return None
+        matcher = self._spark_matchers.matchers.get(node.func.attr, None)
+        if matcher is None:
+            return None
+        return matcher if matcher.matches(node) else None
diff --git a/tests/unit/source_code/test_pyspark.py b/tests/unit/source_code/test_pyspark.py
index 22b08a794a..5deec73178 100644
--- a/tests/unit/source_code/test_pyspark.py
+++ b/tests/unit/source_code/test_pyspark.py
@@ -1,18 +1,20 @@
-from databricks.labs.ucx.source_code.base import Deprecation
-from databricks.labs.ucx.source_code.pyspark import SparkSql
+import pytest
+
+from databricks.labs.ucx.source_code.base import Advisory, Deprecation
+from databricks.labs.ucx.source_code.pyspark import SparkMatchers, SparkSql
 from databricks.labs.ucx.source_code.queries import FromTable
 
 
-def test_spark_not_sql(empty_index):
+def test_spark_no_sql(empty_index):
     ftf = FromTable(empty_index)
-    sqf = SparkSql(ftf)
+    sqf = SparkSql(ftf, empty_index)
 
     assert not list(sqf.lint("print(1)"))
 
 
 def test_spark_sql_no_match(empty_index):
     ftf = FromTable(empty_index)
-    sqf = SparkSql(ftf)
+    sqf = SparkSql(ftf, empty_index)
 
     old_code = """
 spark.read.csv("s3://bucket/path")
@@ -26,7 +28,7 @@ def test_spark_sql_no_match(empty_index):
 
 def test_spark_sql_match(migration_index):
     ftf = FromTable(migration_index)
-    sqf = SparkSql(ftf)
+    sqf = SparkSql(ftf, migration_index)
 
     old_code = """
 spark.read.csv("s3://bucket/path")
@@ -46,9 +48,197 @@ def test_spark_sql_match(migration_index):
     ] == list(sqf.lint(old_code))
 
 
+def test_spark_sql_match_named(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    result = spark.sql(args=[1], sqlQuery = "SELECT * FROM old.things").collect()
+    print(len(result))
+"""
+    assert [
+        Deprecation(
+            code='table-migrate',
+            message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
+            start_line=4,
+            start_col=13,
+            end_line=4,
+            end_col=71,
+        )
+    ] == list(sqf.lint(old_code))
+
+
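+# a sample of the table-name-taking methods wired into SparkMatchers; each is
+# exercised by the parametrized tests below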
+METHOD_NAMES = [
+    "cacheTable",
+    "createTable",
+    "createExternalTable",
+    "getTable",
+    "isCached",
+    "listColumns",
+    "tableExists",
+    "recoverPartitions",
+    "refreshTable",
+    "uncacheTable",
+    "table",
+    "insertInto",
+    "saveAsTable",
+    "register",
+]
+
+
+@pytest.mark.parametrize("method_name", METHOD_NAMES)
+def test_spark_table_match(migration_index, method_name):
+    spark_matchers = SparkMatchers()
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    matcher = spark_matchers.matchers[method_name]
+    args_list = ["a"] * min(5, matcher.max_args)
+    args_list[matcher.table_arg_index] = '"old.things"'
+    args = ",".join(args_list)
+    old_code = f"""
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.{method_name}({args})
+    do_stuff_with_df(df)
+"""
+    assert [
+        Deprecation(
+            code='table-migrate',
+            message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
+            start_line=4,
+            start_col=9,
+            end_line=4,
+            end_col=17 + len(method_name) + len(args),
+        )
+    ] == list(sqf.lint(old_code))
+
+
+@pytest.mark.parametrize("method_name", METHOD_NAMES)
+def test_spark_table_no_match(migration_index, method_name):
+    spark_matchers = SparkMatchers()
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    matcher = spark_matchers.matchers[method_name]
+    args_list = ["a"] * min(5, matcher.max_args)
+    args_list[matcher.table_arg_index] = '"table.we.know.nothing.about"'
+    args = ",".join(args_list)
+    old_code = f"""
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.{method_name}({args})
+    do_stuff_with_df(df)
+"""
+    assert not list(sqf.lint(old_code))
+
+
+@pytest.mark.parametrize("method_name", METHOD_NAMES)
+def test_spark_table_too_many_args(migration_index, method_name):
+    spark_matchers = SparkMatchers()
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    matcher = spark_matchers.matchers[method_name]
+    if matcher.max_args > 100:
+        return
+    args_list = ["a"] * (matcher.max_args + 1)
+    args_list[matcher.table_arg_index] = '"table.we.know.nothing.about"'
+    args = ",".join(args_list)
+    old_code = f"""
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.{method_name}({args})
+    do_stuff_with_df(df)
+"""
+    assert not list(sqf.lint(old_code))
+
+
+def test_spark_table_named_args(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.saveAsTable(format="xyz", name="old.things")
+    do_stuff_with_df(df)
+"""
+    assert [
+        Deprecation(
+            code='table-migrate',
+            message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
+            start_line=4,
+            start_col=9,
+            end_line=4,
+            end_col=59,
+        )
+    ] == list(sqf.lint(old_code))
+
+
+def test_spark_table_variable_arg(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.saveAsTable(name)
+    do_stuff_with_df(df)
+"""
+    assert [
+        Advisory(
+            code='table-migrate',
+            message="Can't migrate 'saveAsTable' because its table name argument is not a constant",
+            start_line=4,
+            start_col=9,
+            end_line=4,
+            end_col=32,
+        )
+    ] == list(sqf.lint(old_code))
+
+
+def test_spark_table_fstring_arg(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for i in range(10):
+    df = spark.saveAsTable(f"boop{stuff}")
+    do_stuff_with_df(df)
+"""
+    assert [
+        Advisory(
+            code='table-migrate',
+            message="Can't migrate 'saveAsTable' because its table name argument is not a constant",
+            start_line=4,
+            start_col=9,
+            end_line=4,
+            end_col=42,
+        )
+    ] == list(sqf.lint(old_code))
+
+
+def test_spark_table_return_value(migration_index):
+    ftf = FromTable(migration_index)
+    sqf = SparkSql(ftf, migration_index)
+    old_code = """
+spark.read.csv("s3://bucket/path")
+for table in spark.listTables():
+    do_stuff_with_table(table)
+"""
+    assert [
+        Advisory(
+            code='table-migrate',
+            message="Call to 'listTables' will return a list of <catalog>.<database>.<table> instead of <database>.<table>.",
+            start_line=3,
+            start_col=13,
+            end_line=3,
+            end_col=31,
+        )
+    ] == list(sqf.lint(old_code))
+
+
.", + start_line=3, + start_col=13, + end_line=3, + end_col=31, + ) + ] == list(sqf.lint(old_code)) + + def test_spark_sql_fix(migration_index): ftf = FromTable(migration_index) - sqf = SparkSql(ftf) + sqf = SparkSql(ftf, migration_index) old_code = """spark.read.csv("s3://bucket/path") for i in range(10):