From f892e1e33009b3542efdeb47666e0a7595326090 Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Fri, 5 Jul 2024 17:30:01 +0200
Subject: [PATCH 1/3] Add more checks for spark-connect linter

---
 .../labs/ucx/source_code/linters/context.py   | 12 ++-
 .../ucx/source_code/linters/python_ast.py     | 10 +++
 .../ucx/source_code/linters/spark_connect.py  | 74 ++++++++++++++++++-
 .../spark-connect/catalog-api_13_3.py         | 11 +++
 .../spark-connect/catalog-api_14_3.py         |  9 +++
 .../spark-connect/command-context.py          |  6 ++
 .../{ => spark-connect}/jvm-access.py         |  0
 .../spark-connect/python-udfs_13_3.py         | 52 +++++++++++++
 .../spark-connect/python-udfs_14_3.py         | 38 ++++++++++
 .../{ => spark-connect}/rdd-apis.py           |  0
 .../{ => spark-connect}/spark-logging.py      |  0
 tests/unit/source_code/test_functional.py     | 11 ++-
 12 files changed, 217 insertions(+), 6 deletions(-)
 create mode 100644 tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
 create mode 100644 tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py
 create mode 100644 tests/unit/source_code/samples/functional/spark-connect/command-context.py
 rename tests/unit/source_code/samples/functional/{ => spark-connect}/jvm-access.py (100%)
 create mode 100644 tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
 create mode 100644 tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
 rename tests/unit/source_code/samples/functional/{ => spark-connect}/rdd-apis.py (100%)
 rename tests/unit/source_code/samples/functional/{ => spark-connect}/spark-logging.py (100%)

diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py
index fad60107f9..7303b4785e 100644
--- a/src/databricks/labs/ucx/source_code/linters/context.py
+++ b/src/databricks/labs/ucx/source_code/linters/context.py
@@ -21,7 +21,13 @@ class LinterContext:
-    def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSessionState | None = None):
+    def __init__(
+        self,
+        index: MigrationIndex | None = None,
+        session_state: CurrentSessionState | None = None,
+        dbr_version: tuple[int, int] | None = None,
+        is_serverless: bool = False,
+    ):
         self._index = index
         session_state = CurrentSessionState() if not session_state else session_state
@@ -41,8 +47,8 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe
         python_linters += [
             DBFSUsageLinter(session_state),
-            DBRv8d0Linter(dbr_version=None),
-            SparkConnectLinter(is_serverless=False),
+            DBRv8d0Linter(dbr_version=dbr_version),
+            SparkConnectLinter(dbr_version=dbr_version, is_serverless=is_serverless),
             DbutilsLinter(session_state),
         ]
         sql_linters.append(FromDbfsFolder())
diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py
index 97959c3687..704a17293f 100644
--- a/src/databricks/labs/ucx/source_code/linters/python_ast.py
+++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py
@@ -172,6 +172,16 @@ def __repr__(self):
     def get_full_attribute_name(cls, node: Attribute) -> str:
         return cls._get_attribute_value(node)
 
+    @classmethod
+    def get_function_name(cls, node: Call) -> str | None:
+        if not isinstance(node, Call):
+            return None
+        if isinstance(node.func, Attribute):
+            return node.func.attrname
+        if isinstance(node.func, Name):
+            return node.func.name
+        return None
+
     @classmethod
     def get_full_function_name(cls, node: Call) -> str | None:
         if not isinstance(node, Call):
diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py
index 962057d2c8..696dfea421 100644
--- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py
+++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py
@@ -2,7 +2,7 @@
 from collections.abc import Iterator
 from dataclasses import dataclass
 
-from astroid import Attribute, Call, Name, NodeNG  # type: ignore
+from astroid import Attribute, Call, Const, Name, NodeNG  # type: ignore
 from databricks.labs.ucx.source_code.base import (
     Advice,
     Failure,
@@ -172,14 +172,84 @@ def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]:
             )
 
 
+@dataclass
+class UDFMatcher(SharedClusterMatcher):
+    _DBR_14_2_BELOW_NOT_SUPPORTED = ["applyInPandas", "mapInPandas", "applyInPandasWithState", "udtf", "pandas_udf"]
+
+    dbr_version: tuple[int, int] | None
+
+    def lint(self, node: NodeNG) -> Iterator[Advice]:
+        if not isinstance(node, Call):
+            return
+        function_name = Tree.get_function_name(node)
+
+        if function_name == 'registerJavaFunction':
+            yield Failure.from_node(
+                code='python-udf-in-shared-clusters',
+                message=f'Cannot register Java UDF from Python code on {self._cluster_type_str()}. '
+                f'Use a %scala cell to register the Scala UDF using spark.udf.register.',
+                node=node,
+            )
+
+        if (
+            function_name in UDFMatcher._DBR_14_2_BELOW_NOT_SUPPORTED
+            and self.dbr_version
+            and self.dbr_version < (14, 3)
+        ):
+            yield Failure.from_node(
+                code='python-udf-in-shared-clusters',
+                message=f'{function_name} require DBR 14.3 LTS or above on {self._cluster_type_str()}',
+                node=node,
+            )
+
+        if function_name == 'udf' and self.dbr_version and self.dbr_version < (14, 3):
+            for keyword in node.keywords:
+                if keyword.arg == 'useArrow' and isinstance(keyword.value, Const) and keyword.value.value:
+                    yield Failure.from_node(
+                        code='python-udf-in-shared-clusters',
+                        message=f'Arrow UDFs require DBR 14.3 LTS or above on {self._cluster_type_str()}',
+                        node=node,
+                    )
+
+
+class CatalogApiMatcher(SharedClusterMatcher):
+    def lint(self, node: NodeNG) -> Iterator[Advice]:
+        if not isinstance(node, Attribute):
+            return
+        if node.attrname == 'catalog' and Tree.get_full_attribute_name(node).endswith('spark.catalog'):
+            yield Failure.from_node(
+                code='catalog-api-in-shared-clusters',
+                message=f'spark.catalog functions require DBR 14.3 LTS or above on {self._cluster_type_str()}',
+                node=node,
+            )
+
+
+class CommandContextMatcher(SharedClusterMatcher):
+    def lint(self, node: NodeNG) -> Iterator[Advice]:
+        if not isinstance(node, Call):
+            return
+        function_name = Tree.get_full_function_name(node)
+        if function_name and function_name.endswith('getContext.toJson'):
+            yield Failure.from_node(
+                code='toJson-in-shared-clusters',
+                message=f'toJson() is not available on {self._cluster_type_str()}. '
+                f'Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.',
+                node=node,
+            )
+
+
 class SparkConnectLinter(PythonLinter):
-    def __init__(self, is_serverless: bool = False):
+    def __init__(self, dbr_version: tuple[int, int] | None = None, is_serverless: bool = False):
         self._matchers = [
             JvmAccessMatcher(is_serverless=is_serverless),
             RDDApiMatcher(is_serverless=is_serverless),
             SparkSqlContextMatcher(is_serverless=is_serverless),
             LoggingMatcher(is_serverless=is_serverless),
+            UDFMatcher(is_serverless=is_serverless, dbr_version=dbr_version),
+            CommandContextMatcher(is_serverless=is_serverless),
         ]
+        if dbr_version and dbr_version < (14, 3):
+            self._matchers.append(CatalogApiMatcher(is_serverless=is_serverless))
 
     def lint_tree(self, tree: Tree) -> Iterator[Advice]:
         for matcher in self._matchers:
diff --git a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
new file mode 100644
index 0000000000..b928f2e572
--- /dev/null
+++ b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
@@ -0,0 +1,11 @@
+# ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters
+spark.catalog.tableExists("table")
+# ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters
+spark.catalog.listDatabases()
+
+
+def catalog():
+    pass
+
+
+catalog()
diff --git a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py
new file mode 100644
index 0000000000..9fced8bca4
--- /dev/null
+++ b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py
@@ -0,0 +1,9 @@
+spark.catalog.tableExists("table")
+spark.catalog.listDatabases()
+
+
+def catalog():
+    pass
+
+
+catalog()
diff --git a/tests/unit/source_code/samples/functional/spark-connect/command-context.py b/tests/unit/source_code/samples/functional/spark-connect/command-context.py
new file mode 100644
index 0000000000..18cba493ab
--- /dev/null
+++ b/tests/unit/source_code/samples/functional/spark-connect/command-context.py
@@ -0,0 +1,6 @@
+# ucx[toJson-in-shared-clusters:+1:6:+1:80] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
+print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson())
+dbutils.notebook.entry_point.getDbutils().notebook().getContext().toSafeJson()
+notebook = dbutils.notebook.entry_point.getDbutils().notebook()
+# ucx[toJson-in-shared-clusters:+1:0:+1:30] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
+notebook.getContext().toJson()
diff --git a/tests/unit/source_code/samples/functional/jvm-access.py b/tests/unit/source_code/samples/functional/spark-connect/jvm-access.py
similarity index 100%
rename from tests/unit/source_code/samples/functional/jvm-access.py
rename to tests/unit/source_code/samples/functional/spark-connect/jvm-access.py
diff --git a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
new file mode 100644
index 0000000000..62841488e0
--- /dev/null
+++ b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
@@ -0,0 +1,52 @@
+from pyspark.sql.functions import udf, udtf, lit
+import pandas as pd
+
+
+@udf(returnType='int')
+def slen(s):
+    return len(s)
+
+
+# ucx[python-udf-in-shared-clusters:+1:1:+1:37] Arrow UDFs require DBR 14.3 LTS or above on UC Shared Clusters
+@udf(returnType='int', useArrow=True)
+def arrow_slen(s):
+    return len(s)
+
+
+df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+df.select(slen("name"), arrow_slen("name")).show()
+
+slen1 = udf(lambda s: len(s), returnType='int')
+# ucx[python-udf-in-shared-clusters:+1:14:+1:68] Arrow UDFs require DBR 14.3 LTS or above on UC Shared Clusters
+arrow_slen1 = udf(lambda s: len(s), returnType='int', useArrow=True)
+
+df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+
+df.select(slen1("name"), arrow_slen1("name")).show()
+
+df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+
+
+def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
+    v = pdf.v
+    return pdf.assign(v=v - v.mean())
+
+
+# ucx[python-udf-in-shared-clusters:+1:0:+1:73] applyInPandas require DBR 14.3 LTS or above on UC Shared Clusters
+df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
+
+
+class SquareNumbers:
+    def eval(self, start: int, end: int):
+        for num in range(start, end + 1):
+            yield (num, num * num)
+
+
+# ucx[python-udf-in-shared-clusters:+1:13:+1:69] udtf require DBR 14.3 LTS or above on UC Shared Clusters
+square_num = udtf(SquareNumbers, returnType="num: int, squared: int")
+square_num(lit(1), lit(3)).show()
+
+from pyspark.sql.types import IntegerType
+
+# ucx[python-udf-in-shared-clusters:+1:0:+1:73] Cannot register Java UDF from Python code on UC Shared Clusters. Use a %scala cell to register the Scala UDF using spark.udf.register.
+spark.udf.registerJavaFunction("func", "org.example.func", IntegerType())
diff --git a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
new file mode 100644
index 0000000000..6de02266f0
--- /dev/null
+++ b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
@@ -0,0 +1,38 @@
+from pyspark.sql.functions import udf, udtf, lit
+import pandas as pd
+
+
+@udf(returnType='int')
+def slen(s):
+    return len(s)
+
+
+@udf(returnType='int', useArrow=True)
+def arrow_slen(s):
+    return len(s)
+
+
+df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+df.select(slen("name"), arrow_slen("name")).show()
+
+slen1 = udf(lambda s: len(s), returnType='int')
+arrow_slen1 = udf(lambda s: len(s), returnType='int', useArrow=True)
+
+df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+
+df.select(slen1("name"), arrow_slen1("name")).show()
+
+df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+
+
+def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
+    v = pdf.v
+    return pdf.assign(v=v - v.mean())
+
+
+df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
+
+from pyspark.sql.types import IntegerType
+
+# ucx[python-udf-in-shared-clusters:+1:0:+1:73] Cannot register Java UDF from Python code on UC Shared Clusters. Use a %scala cell to register the Scala UDF using spark.udf.register.
+spark.udf.registerJavaFunction("func", "org.example.func", IntegerType())
diff --git a/tests/unit/source_code/samples/functional/rdd-apis.py b/tests/unit/source_code/samples/functional/spark-connect/rdd-apis.py
similarity index 100%
rename from tests/unit/source_code/samples/functional/rdd-apis.py
rename to tests/unit/source_code/samples/functional/spark-connect/rdd-apis.py
diff --git a/tests/unit/source_code/samples/functional/spark-logging.py b/tests/unit/source_code/samples/functional/spark-connect/spark-logging.py
similarity index 100%
rename from tests/unit/source_code/samples/functional/spark-logging.py
rename to tests/unit/source_code/samples/functional/spark-connect/spark-logging.py
diff --git a/tests/unit/source_code/test_functional.py b/tests/unit/source_code/test_functional.py
index a81ca8be45..b0bf06ea80 100644
--- a/tests/unit/source_code/test_functional.py
+++ b/tests/unit/source_code/test_functional.py
@@ -62,6 +62,13 @@ class Functional:
     )
     _location = Path(__file__).parent / 'samples/functional'
 
+    TEST_DBR_VERSION = {
+        'python-udfs_13_3.py': (13, 3),
+        'catalog-api_13_3.py': (13, 3),
+        'python-udfs_14_3.py': (14, 3),
+        'catalog-api_14_3.py': (14, 3),
+    }
+
     @classmethod
     def all(cls) -> list['Functional']:
         return [Functional(_) for _ in cls._location.glob('**/*.py')]
@@ -104,7 +111,9 @@ def _lint(self) -> Iterable[Advice]:
         )
         session_state = CurrentSessionState()
         session_state.named_parameters = {"my-widget": "my-path.py"}
-        ctx = LinterContext(migration_index, session_state)
+        ctx = LinterContext(
+            migration_index, session_state, dbr_version=Functional.TEST_DBR_VERSION.get(self.path.name, None)
+        )
         linter = FileLinter(ctx, self.path)
         return linter.lint()
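
Note on the intermediate API introduced by PATCH 1/3: `LinterContext` threads `dbr_version` down to `SparkConnectLinter`, so the UDF and catalog-API matchers only fire on runtimes below 14.3. The snippet below is an illustrative sketch only (not part of the patch); it assumes the `PythonLinter.lint(code)` helper used by the existing unit tests, which parses a snippet and dispatches to `lint_tree`.

```python
# Hypothetical usage sketch for the PATCH 1/3 constructor signature.
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter

code = """
spark.udf.registerJavaFunction("func", "org.example.func", IntegerType())
df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
"""

# On (13, 3) the UDFMatcher should flag both lines; on (14, 3) only registerJavaFunction.
linter = SparkConnectLinter(dbr_version=(13, 3), is_serverless=False)
for advice in linter.lint(code):
    print(advice.code, advice.message)
```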

From 8371ea53e5ed768edfd669d93f89f86d8a97faf4 Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Mon, 8 Jul 2024 17:53:44 +0200
Subject: [PATCH 2/3] Address review feedback

---
 src/databricks/labs/ucx/source_code/base.py    | 14 ++++
 src/databricks/labs/ucx/source_code/jobs.py    |  4 ++
 .../labs/ucx/source_code/linters/context.py    | 12 +---
 .../ucx/source_code/linters/spark_connect.py   | 34 +++++----
 .../source_code/linters/test_spark_connect.py  | 22 +++---
 .../spark-connect/catalog-api_13_3.py          |  9 +++
 .../spark-connect/catalog-api_14_3.py          |  1 +
 .../spark-connect/command-context.py           |  4 +-
 .../spark-connect/python-udfs_13_3.py          |  1 +
 .../spark-connect/python-udfs_14_3.py          |  1 +
 tests/unit/source_code/test_functional.py      | 72 +++++++++++--------
 11 files changed, 103 insertions(+), 71 deletions(-)

diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py
index ccd1a1fa7d..8704541a91 100644
--- a/src/databricks/labs/ucx/source_code/base.py
+++ b/src/databricks/labs/ucx/source_code/base.py
@@ -172,6 +172,20 @@ class CurrentSessionState:
     spark_conf: dict[str, str] | None = None
     named_parameters: dict[str, str] | None = None
     data_security_mode: compute.DataSecurityMode | None = None
+    is_serverless: bool = False
+    dbr_version: tuple[int, int] | None = None
+
+    @classmethod
+    def from_json(cls, json: dict) -> CurrentSessionState:
+        return cls(
+            schema=json.get('schema', DEFAULT_SCHEMA),
+            catalog=json.get('catalog', DEFAULT_CATALOG),
+            spark_conf=json.get('spark_conf', None),
+            named_parameters=json.get('named_parameters', None),
+            data_security_mode=json.get('data_security_mode', None),
+            is_serverless=json.get('is_serverless', False),
+            dbr_version=tuple(json['dbr_version']) if 'dbr_version' in json else None,
+        )
 
 
 class SequentialLinter(Linter):
diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py
index aa82e9557f..3e2c800ad1 100644
--- a/src/databricks/labs/ucx/source_code/jobs.py
+++ b/src/databricks/labs/ucx/source_code/jobs.py
@@ -76,6 +76,7 @@ def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
         self._spark_conf: dict[str, str] | None = {}
         self._spark_version: str | None = None
         self._data_security_mode = None
+        self._is_serverless = False
 
     @property
     def named_parameters(self) -> dict[str, str]:
@@ -268,6 +269,8 @@ def _register_cluster_info(self):
             if job_cluster.job_cluster_key != self._task.job_cluster_key:
                 continue
             return self._new_job_cluster_metadata(job_cluster.new_cluster)
+        self._data_security_mode = compute.DataSecurityMode.USER_ISOLATION
+        self._is_serverless = True
         return []
 
     def _new_job_cluster_metadata(self, new_cluster):
@@ -357,6 +360,7 @@ def _lint_task(self, task: jobs.Task, job: jobs.Job):
             data_security_mode=container.data_security_mode,
             named_parameters=container.named_parameters,
             spark_conf=container.spark_conf,
+            dbr_version=container.runtime_version,
         )
         graph = DependencyGraph(dependency, None, self._resolver, self._path_lookup, session_state)
         problems = container.build_dependency_graph(graph)
diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py
index 7303b4785e..41b9212c7f 100644
--- a/src/databricks/labs/ucx/source_code/linters/context.py
+++ b/src/databricks/labs/ucx/source_code/linters/context.py
@@ -21,13 +21,7 @@ class LinterContext:
-    def __init__(
-        self,
-        index: MigrationIndex | None = None,
-        session_state: CurrentSessionState | None = None,
-        dbr_version: tuple[int, int] | None = None,
-        is_serverless: bool = False,
-    ):
+    def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSessionState | None = None):
         self._index = index
         session_state = CurrentSessionState() if not session_state else session_state
@@ -47,8 +41,8 @@ def __init__(
         python_linters += [
             DBFSUsageLinter(session_state),
-            DBRv8d0Linter(dbr_version=dbr_version),
-            SparkConnectLinter(dbr_version=dbr_version, is_serverless=is_serverless),
+            DBRv8d0Linter(dbr_version=session_state.dbr_version),
+            SparkConnectLinter(session_state),
             DbutilsLinter(session_state),
         ]
         sql_linters.append(FromDbfsFolder())
diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py
index 696dfea421..7ff2dbd481 100644
--- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py
+++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py
@@ -7,16 +7,17 @@
     Advice,
     Failure,
     PythonLinter,
+    CurrentSessionState,
 )
 from databricks.labs.ucx.source_code.linters.python_ast import Tree
 
 
 @dataclass
 class SharedClusterMatcher:
-    is_serverless: bool
+    session_state: CurrentSessionState
 
     def _cluster_type_str(self) -> str:
-        return 'UC Shared Clusters' if not self.is_serverless else 'Serverless Compute'
+        return 'UC Shared Clusters' if not self.session_state.is_serverless else 'Serverless Compute'
 
     @abstractmethod
     def lint(self, node: NodeNG) -> Iterator[Advice]:
@@ -172,12 +173,9 @@ def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]:
         )
 
 
-@dataclass
 class UDFMatcher(SharedClusterMatcher):
     _DBR_14_2_BELOW_NOT_SUPPORTED = ["applyInPandas", "mapInPandas", "applyInPandasWithState", "udtf", "pandas_udf"]
 
-    dbr_version: tuple[int, int] | None
-
     def lint(self, node: NodeNG) -> Iterator[Advice]:
         if not isinstance(node, Call):
             return
         function_name = Tree.get_function_name(node)
@@ -193,8 +191,8 @@ def lint(self, node: NodeNG) -> Iterator[Advice]:
 
         if (
             function_name in UDFMatcher._DBR_14_2_BELOW_NOT_SUPPORTED
-            and self.dbr_version
-            and self.dbr_version < (14, 3)
+            and self.session_state.dbr_version
+            and self.session_state.dbr_version < (14, 3)
         ):
             yield Failure.from_node(
                 code='python-udf-in-shared-clusters',
@@ -202,7 +200,7 @@ def lint(self, node: NodeNG) -> Iterator[Advice]:
                 node=node,
             )
 
-        if function_name == 'udf' and self.dbr_version and self.dbr_version < (14, 3):
+        if function_name == 'udf' and self.session_state.dbr_version and self.session_state.dbr_version < (14, 3):
             for keyword in node.keywords:
                 if keyword.arg == 'useArrow' and isinstance(keyword.value, Const) and keyword.value.value:
                     yield Failure.from_node(
@@ -231,7 +229,7 @@ def lint(self, node: NodeNG) -> Iterator[Advice]:
         function_name = Tree.get_full_function_name(node)
         if function_name and function_name.endswith('getContext.toJson'):
             yield Failure.from_node(
-                code='toJson-in-shared-clusters',
+                code='to-json-in-shared-clusters',
                 message=f'toJson() is not available on {self._cluster_type_str()}. '
                 f'Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.',
                 node=node,
             )
 
 
 class SparkConnectLinter(PythonLinter):
-    def __init__(self, dbr_version: tuple[int, int] | None = None, is_serverless: bool = False):
+    def __init__(self, session_state: CurrentSessionState):
         self._matchers = [
-            JvmAccessMatcher(is_serverless=is_serverless),
-            RDDApiMatcher(is_serverless=is_serverless),
-            SparkSqlContextMatcher(is_serverless=is_serverless),
-            LoggingMatcher(is_serverless=is_serverless),
-            UDFMatcher(is_serverless=is_serverless, dbr_version=dbr_version),
-            CommandContextMatcher(is_serverless=is_serverless),
+            JvmAccessMatcher(session_state=session_state),
+            RDDApiMatcher(session_state=session_state),
+            SparkSqlContextMatcher(session_state=session_state),
+            LoggingMatcher(session_state=session_state),
+            UDFMatcher(session_state=session_state),
+            CommandContextMatcher(session_state=session_state),
         ]
-        if dbr_version and dbr_version < (14, 3):
-            self._matchers.append(CatalogApiMatcher(is_serverless=is_serverless))
+        if session_state.dbr_version and session_state.dbr_version < (14, 3):
+            self._matchers.append(CatalogApiMatcher(session_state=session_state))
 
     def lint_tree(self, tree: Tree) -> Iterator[Advice]:
         for matcher in self._matchers:
diff --git a/tests/unit/source_code/linters/test_spark_connect.py b/tests/unit/source_code/linters/test_spark_connect.py
index d33a64117f..b9a6ed8a73 100644
--- a/tests/unit/source_code/linters/test_spark_connect.py
+++ b/tests/unit/source_code/linters/test_spark_connect.py
@@ -1,12 +1,12 @@
 from itertools import chain
 
-from databricks.labs.ucx.source_code.base import Failure
+from databricks.labs.ucx.source_code.base import Failure, CurrentSessionState
 from databricks.labs.ucx.source_code.linters.python_ast import Tree
 from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter
 
 
 def test_jvm_access_match_shared():
-    linter = SparkConnectLinter(is_serverless=False)
+    linter = SparkConnectLinter(CurrentSessionState())
     code = """
 spark.range(10).collect()
 spark._jspark._jvm.com.my.custom.Name()
@@ -26,7 +26,7 @@
 
 
 def test_jvm_access_match_serverless():
-    linter = SparkConnectLinter(is_serverless=True)
+    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
     code = """
 spark.range(10).collect()
 spark._jspark._jvm.com.my.custom.Name()
@@ -47,7 +47,7 @@
 
 
 def test_rdd_context_match_shared():
-    linter = SparkConnectLinter(is_serverless=False)
+    linter = SparkConnectLinter(CurrentSessionState())
     code = """
 rdd1 = sc.parallelize([1, 2, 3])
 rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -91,7 +91,7 @@
 
 
 def test_rdd_context_match_serverless():
-    linter = SparkConnectLinter(is_serverless=True)
+    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
     code = """
 rdd1 = sc.parallelize([1, 2, 3])
 rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -133,7 +133,7 @@
 
 
 def test_rdd_map_partitions():
-    linter = SparkConnectLinter(is_serverless=False)
+    linter = SparkConnectLinter(CurrentSessionState())
     code = """
 df = spark.createDataFrame([])
 df.rdd.mapPartitions(myUdf)
@@ -153,7 +153,7 @@
 
 
 def test_conf_shared():
-    linter = SparkConnectLinter(is_serverless=False)
+    linter = SparkConnectLinter(CurrentSessionState())
"""df.sparkContext.getConf().get('spark.my.conf')""" assert [ Failure( @@ -168,7 +168,7 @@ def test_conf_shared(): def test_conf_serverless(): - linter = SparkConnectLinter(is_serverless=True) + linter = SparkConnectLinter(CurrentSessionState(is_serverless=True)) code = """sc._conf().get('spark.my.conf')""" expected = [ Failure( @@ -185,7 +185,7 @@ def test_conf_serverless(): def test_logging_shared(): - logging_matcher = LoggingMatcher(is_serverless=False) + logging_matcher = LoggingMatcher(CurrentSessionState()) code = """ sc.setLogLevel("INFO") setLogLevel("WARN") @@ -226,7 +226,7 @@ def test_logging_shared(): def test_logging_serverless(): - logging_matcher = LoggingMatcher(is_serverless=True) + logging_matcher = LoggingMatcher(CurrentSessionState(is_serverless=True)) code = """ sc.setLogLevel("INFO") log4jLogger = sc._jvm.org.apache.log4j @@ -255,7 +255,7 @@ def test_logging_serverless(): def test_valid_code(): - linter = SparkConnectLinter() + linter = SparkConnectLinter(CurrentSessionState()) code = """ df = spark.range(10) df.collect() diff --git a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py index b928f2e572..5f44658eba 100644 --- a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py +++ b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py @@ -1,3 +1,4 @@ +# ucx[session-state] {"dbr_version": [13, 3]} # ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters spark.catalog.tableExists("table") # ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters @@ -9,3 +10,11 @@ def catalog(): catalog() + +class Fatalog: + def tableExists(self, x): ... +class Foo: + def catalog(self): Fatalog() + +x = Foo() +x.catalog.tableExists("...") diff --git a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py index 9fced8bca4..08e42e5a39 100644 --- a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py +++ b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_14_3.py @@ -1,3 +1,4 @@ +# ucx[session-state] {"dbr_version": [14, 3]} spark.catalog.tableExists("table") spark.catalog.listDatabases() diff --git a/tests/unit/source_code/samples/functional/spark-connect/command-context.py b/tests/unit/source_code/samples/functional/spark-connect/command-context.py index 18cba493ab..4958f6af4a 100644 --- a/tests/unit/source_code/samples/functional/spark-connect/command-context.py +++ b/tests/unit/source_code/samples/functional/spark-connect/command-context.py @@ -1,6 +1,6 @@ -# ucx[toJson-in-shared-clusters:+1:6:+1:80] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information. +# ucx[to-json-in-shared-clusters:+1:6:+1:80] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information. print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson()) dbutils.notebook.entry_point.getDbutils().notebook().getContext().toSafeJson() notebook = dbutils.notebook.entry_point.getDbutils().notebook() -# ucx[toJson-in-shared-clusters:+1:0:+1:30] toJson() is not available on UC Shared Clusters. 
-# ucx[toJson-in-shared-clusters:+1:0:+1:30] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
+# ucx[to-json-in-shared-clusters:+1:0:+1:30] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
 notebook.getContext().toJson()
diff --git a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
index 62841488e0..21eae64c1a 100644
--- a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
+++ b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_13_3.py
@@ -1,3 +1,4 @@
+# ucx[session-state] {"dbr_version": [13, 3]}
 from pyspark.sql.functions import udf, udtf, lit
 import pandas as pd
 
diff --git a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
index 6de02266f0..2dc3cd4622 100644
--- a/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
+++ b/tests/unit/source_code/samples/functional/spark-connect/python-udfs_14_3.py
@@ -1,3 +1,4 @@
+# ucx[session-state] {"dbr_version": [14, 3]}
 from pyspark.sql.functions import udf, udtf, lit
 import pandas as pd
 
diff --git a/tests/unit/source_code/test_functional.py b/tests/unit/source_code/test_functional.py
index b0bf06ea80..52a59ba6e0 100644
--- a/tests/unit/source_code/test_functional.py
+++ b/tests/unit/source_code/test_functional.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import json
 import re
 import tokenize
 from collections.abc import Iterable, Generator
 from dataclasses import dataclass
 from pathlib import Path
+from typing import Any
 
 import pytest
 
@@ -60,14 +62,9 @@ class Functional:
     _re = re.compile(
         r"# ucx\[(?P<code>[\w-]+):(?P<start_line>[\d+]+):(?P<start_col>[\d]+):(?P<end_line>[\d+]+):(?P<end_col>[\d]+)] (?P<message>.*)"
     )
-    _location = Path(__file__).parent / 'samples/functional'
+    _re_session_state = re.compile(r'# ucx\[session-state] (?P<session_state_json>\{.*})')
 
-    TEST_DBR_VERSION = {
-        'python-udfs_13_3.py': (13, 3),
-        'catalog-api_13_3.py': (13, 3),
-        'python-udfs_14_3.py': (14, 3),
-        'catalog-api_14_3.py': (14, 3),
-    }
+    _location = Path(__file__).parent / 'samples/functional'
 
     @classmethod
     def all(cls) -> list['Functional']:
         return [Functional(_) for _ in cls._location.glob('**/*.py')]
@@ -109,41 +106,54 @@ def _lint(self) -> Iterable[Advice]:
                 MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'),
             ]
         )
-        session_state = CurrentSessionState()
+        session_state = self._test_session_state()
+        print(str(session_state))
         session_state.named_parameters = {"my-widget": "my-path.py"}
-        ctx = LinterContext(
-            migration_index, session_state, dbr_version=Functional.TEST_DBR_VERSION.get(self.path.name, None)
-        )
+        ctx = LinterContext(migration_index, session_state)
         linter = FileLinter(ctx, self.path)
         return linter.lint()
 
-    def _expected_problems(self) -> Generator[Expectation, None, None]:
+    def _regex_match(self, regex: re.Pattern[str]) -> Generator[tuple[Comment, dict[str, Any]], None, None]:
         with self.path.open('rb') as f:
             for comment in self._comments(f):
                 if not comment.text.startswith('# ucx['):
                     continue
-                match = self._re.match(comment.text)
+                match = regex.match(comment.text)
                 if not match:
                     continue
                 groups = match.groupdict()
-                reported_start_line = groups['start_line']
-                if '+' in reported_start_line:
-                    start_line = int(reported_start_line[1:]) + comment.start_line
-                else:
-                    start_line = int(reported_start_line)
-                reported_end_line = groups['end_line']
-                if '+' in reported_end_line:
-                    end_line = int(reported_end_line[1:]) + comment.start_line
-                else:
-                    end_line = int(reported_end_line)
-                yield Expectation(
-                    code=groups['code'],
-                    message=groups['message'],
-                    start_line=start_line,
-                    start_col=int(groups['start_col']),
-                    end_line=end_line,
-                    end_col=int(groups['end_col']),
-                )
+                yield comment, groups
+
+    def _expected_problems(self) -> Generator[Expectation, None, None]:
+        for comment, groups in self._regex_match(self._re):
+            reported_start_line = groups['start_line']
+            if '+' in reported_start_line:
+                start_line = int(reported_start_line[1:]) + comment.start_line
+            else:
+                start_line = int(reported_start_line)
+            reported_end_line = groups['end_line']
+            if '+' in reported_end_line:
+                end_line = int(reported_end_line[1:]) + comment.start_line
+            else:
+                end_line = int(reported_end_line)
+            yield Expectation(
+                code=groups['code'],
+                message=groups['message'],
+                start_line=start_line,
+                start_col=int(groups['start_col']),
+                end_line=end_line,
+                end_col=int(groups['end_col']),
+            )
+
+    def _test_session_state(self) -> CurrentSessionState:
+        matches = list(self._regex_match(self._re_session_state))
+        if len(matches) > 1:
+            raise ValueError("A test should have no more than one session state definition")
+        if len(matches) == 0:
+            return CurrentSessionState()
+        groups = matches[0][1]
+        json_str = groups['session_state_json']
+        return CurrentSessionState.from_json(json.loads(json_str))
 
     @staticmethod
     def _comments(f) -> Generator[Comment, None, None]:

From 132d99a7d22301967aeed8934fe0ef4102c846e0 Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Mon, 8 Jul 2024 18:28:20 +0200
Subject: [PATCH 3/3] fmt

---
 .../samples/functional/spark-connect/catalog-api_13_3.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
index 5f44658eba..1721f2a4c3 100644
--- a/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
+++ b/tests/unit/source_code/samples/functional/spark-connect/catalog-api_13_3.py
@@ -11,10 +11,13 @@ def catalog():
 
 catalog()
+
 
 class Fatalog:
-    def tableExists(self, x): ...
+    def tableExists(self, x): ...
 class Foo:
-    def catalog(self): Fatalog()
+    def catalog(self):
+        Fatalog()
+
 
 x = Foo()
 x.catalog.tableExists("...")
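
Taken together, the series moves the runtime facts (DBR version, serverless flag) into `CurrentSessionState`, which the job-graph builder and the functional tests now populate; samples declare the state with a `# ucx[session-state]` header comment. The sketch below is illustrative only and not part of the patches; it assumes, as above, that `PythonLinter.lint(code)` parses the snippet and walks its tree.

```python
# Hypothetical end-to-end usage of the final (PATCH 2/3) API.
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter

# Functional tests encode this as: # ucx[session-state] {"dbr_version": [13, 3]}
# and feed it through CurrentSessionState.from_json in Functional._test_session_state().
state = CurrentSessionState.from_json({"dbr_version": [13, 3]})

linter = SparkConnectLinter(state)
advices = list(linter.lint('spark.catalog.tableExists("table")'))
# Below DBR 14.3 the CatalogApiMatcher is registered, so a
# catalog-api-in-shared-clusters failure is expected; with [14, 3] it would not be.
print([a.code for a in advices])
```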