Add more checks for spark-connect linter (#2092)
## Changes
Add more checks to detect code incompatibilities with UC Shared Clusters, as illustrated in the sketch below:
- use of unsupported Python UDF eval types (e.g. `applyInPandas`, `udtf`, Arrow UDFs)
- use of `spark.catalog.X` APIs on DBR < 14.3
- use of commandContext (`getContext().toJson()`)
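
For illustration, here is a minimal sketch (not part of this change) of running the updated linter against code the new checks are intended to flag; it assumes `PythonLinter.lint(code)` accepts a source string, as the unit tests below do:

```python
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter

# A UC Shared Cluster on DBR 13.3, i.e. below the 14.3 threshold used by the new checks.
linter = SparkConnectLinter(CurrentSessionState(dbr_version=(13, 3)))

code = """
spark.udf.registerJavaFunction("func", "org.example.func", IntegerType())
spark.catalog.tableExists("table")
df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double")
dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson()
"""

for advice in linter.lint(code):
    print(advice.code, advice.message)
```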

### Tests
- [ ] manually tested
- [x] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)
vsevolodstep-db authored Jul 8, 2024
1 parent b2a2ae4 commit b564fe3
Showing 15 changed files with 288 additions and 42 deletions.
14 changes: 14 additions & 0 deletions src/databricks/labs/ucx/source_code/base.py
@@ -172,6 +172,20 @@ class CurrentSessionState:
    spark_conf: dict[str, str] | None = None
    named_parameters: dict[str, str] | None = None
    data_security_mode: compute.DataSecurityMode | None = None
    is_serverless: bool = False
    dbr_version: tuple[int, int] | None = None

    @classmethod
    def from_json(cls, json: dict) -> CurrentSessionState:
        return cls(
            schema=json.get('schema', DEFAULT_SCHEMA),
            catalog=json.get('catalog', DEFAULT_CATALOG),
            spark_conf=json.get('spark_conf', None),
            named_parameters=json.get('named_parameters', None),
            data_security_mode=json.get('data_security_mode', None),
            is_serverless=json.get('is_serverless', False),
            dbr_version=tuple(json['dbr_version']) if 'dbr_version' in json else None,
        )


class SequentialLinter(Linter):
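
For context, a minimal usage sketch (not part of the diff) of the new fields and the `from_json` helper; the payload is hypothetical, shaped like the `# ucx[session-state] {...}` headers used in the functional test samples further down:

```python
from databricks.labs.ucx.source_code.base import CurrentSessionState

# Hypothetical payload, e.g. parsed from a '# ucx[session-state] {"dbr_version": [13, 3]}' comment.
payload = {"dbr_version": [13, 3], "is_serverless": False}

state = CurrentSessionState.from_json(payload)
assert state.dbr_version == (13, 3)  # the JSON list becomes a tuple, so comparisons like < (14, 3) work
assert state.is_serverless is False
```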
4 changes: 4 additions & 0 deletions src/databricks/labs/ucx/source_code/jobs.py
@@ -76,6 +76,7 @@ def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
        self._spark_conf: dict[str, str] | None = {}
        self._spark_version: str | None = None
        self._data_security_mode = None
        self._is_serverless = False

    @property
    def named_parameters(self) -> dict[str, str]:
@@ -268,6 +269,8 @@ def _register_cluster_info(self):
                if job_cluster.job_cluster_key != self._task.job_cluster_key:
                    continue
                return self._new_job_cluster_metadata(job_cluster.new_cluster)
        self._data_security_mode = compute.DataSecurityMode.USER_ISOLATION
        self._is_serverless = True
        return []

    def _new_job_cluster_metadata(self, new_cluster):
@@ -357,6 +360,7 @@ def _lint_task(self, task: jobs.Task, job: jobs.Job):
            data_security_mode=container.data_security_mode,
            named_parameters=container.named_parameters,
            spark_conf=container.spark_conf,
            dbr_version=container.runtime_version,
        )
        graph = DependencyGraph(dependency, None, self._resolver, self._path_lookup, session_state)
        problems = container.build_dependency_graph(graph)
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/source_code/linters/context.py
@@ -41,8 +41,8 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe

        python_linters += [
            DBFSUsageLinter(session_state),
            DBRv8d0Linter(dbr_version=None),
            SparkConnectLinter(is_serverless=False),
            DBRv8d0Linter(dbr_version=session_state.dbr_version),
            SparkConnectLinter(session_state),
            DbutilsLinter(session_state),
        ]
        sql_linters.append(FromDbfsFolder())
10 changes: 10 additions & 0 deletions src/databricks/labs/ucx/source_code/linters/python_ast.py
@@ -172,6 +172,16 @@ def __repr__(self):
    def get_full_attribute_name(cls, node: Attribute) -> str:
        return cls._get_attribute_value(node)

    @classmethod
    def get_function_name(cls, node: Call) -> str | None:
        if not isinstance(node, Call):
            return None
        if isinstance(node.func, Attribute):
            return node.func.attrname
        if isinstance(node.func, Name):
            return node.func.name
        return None

    @classmethod
    def get_full_function_name(cls, node: Call) -> str | None:
        if not isinstance(node, Call):
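
A quick sketch (not part of the diff) of what the new `Tree.get_function_name` helper returns, driving it with astroid directly; the snippets passed to `extract_node` are hypothetical:

```python
import astroid  # the AST library these linters are built on

from databricks.labs.ucx.source_code.linters.python_ast import Tree

# Returns the unqualified function name for both attribute calls (x.f(...)) and name calls (f(...)).
call = astroid.extract_node('spark.udf.registerJavaFunction("f", "org.example.F")')
assert Tree.get_function_name(call) == 'registerJavaFunction'

call = astroid.extract_node('udf(lambda s: len(s), useArrow=True)')
assert Tree.get_function_name(call) == 'udf'
```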
84 changes: 76 additions & 8 deletions src/databricks/labs/ucx/source_code/linters/spark_connect.py
@@ -2,21 +2,22 @@
from collections.abc import Iterator
from dataclasses import dataclass

from astroid import Attribute, Call, Name, NodeNG # type: ignore
from astroid import Attribute, Call, Const, Name, NodeNG # type: ignore
from databricks.labs.ucx.source_code.base import (
    Advice,
    Failure,
    PythonLinter,
    CurrentSessionState,
)
from databricks.labs.ucx.source_code.linters.python_ast import Tree


@dataclass
class SharedClusterMatcher:
    is_serverless: bool
    session_state: CurrentSessionState

    def _cluster_type_str(self) -> str:
        return 'UC Shared Clusters' if not self.is_serverless else 'Serverless Compute'
        return 'UC Shared Clusters' if not self.session_state.is_serverless else 'Serverless Compute'

    @abstractmethod
    def lint(self, node: NodeNG) -> Iterator[Advice]:
@@ -172,14 +173,81 @@ def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]:
            )


class UDFMatcher(SharedClusterMatcher):
    _DBR_14_2_BELOW_NOT_SUPPORTED = ["applyInPandas", "mapInPandas", "applyInPandasWithState", "udtf", "pandas_udf"]

    def lint(self, node: NodeNG) -> Iterator[Advice]:
        if not isinstance(node, Call):
            return
        function_name = Tree.get_function_name(node)

        if function_name == 'registerJavaFunction':
            yield Failure.from_node(
                code='python-udf-in-shared-clusters',
                message=f'Cannot register Java UDF from Python code on {self._cluster_type_str()}. '
                f'Use a %scala cell to register the Scala UDF using spark.udf.register.',
                node=node,
            )

        if (
            function_name in UDFMatcher._DBR_14_2_BELOW_NOT_SUPPORTED
            and self.session_state.dbr_version
            and self.session_state.dbr_version < (14, 3)
        ):
            yield Failure.from_node(
                code='python-udf-in-shared-clusters',
                message=f'{function_name} require DBR 14.3 LTS or above on {self._cluster_type_str()}',
                node=node,
            )

        if function_name == 'udf' and self.session_state.dbr_version and self.session_state.dbr_version < (14, 3):
            for keyword in node.keywords:
                if keyword.arg == 'useArrow' and isinstance(keyword.value, Const) and keyword.value.value:
                    yield Failure.from_node(
                        code='python-udf-in-shared-clusters',
                        message=f'Arrow UDFs require DBR 14.3 LTS or above on {self._cluster_type_str()}',
                        node=node,
                    )


class CatalogApiMatcher(SharedClusterMatcher):
    def lint(self, node: NodeNG) -> Iterator[Advice]:
        if not isinstance(node, Attribute):
            return
        if node.attrname == 'catalog' and Tree.get_full_attribute_name(node).endswith('spark.catalog'):
            yield Failure.from_node(
                code='catalog-api-in-shared-clusters',
                message=f'spark.catalog functions require DBR 14.3 LTS or above on {self._cluster_type_str()}',
                node=node,
            )


class CommandContextMatcher(SharedClusterMatcher):
    def lint(self, node: NodeNG) -> Iterator[Advice]:
        if not isinstance(node, Call):
            return
        function_name = Tree.get_full_function_name(node)
        if function_name and function_name.endswith('getContext.toJson'):
            yield Failure.from_node(
                code='to-json-in-shared-clusters',
                message=f'toJson() is not available on {self._cluster_type_str()}. '
                f'Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.',
                node=node,
            )


class SparkConnectLinter(PythonLinter):
    def __init__(self, is_serverless: bool = False):
    def __init__(self, session_state: CurrentSessionState):
        self._matchers = [
            JvmAccessMatcher(is_serverless=is_serverless),
            RDDApiMatcher(is_serverless=is_serverless),
            SparkSqlContextMatcher(is_serverless=is_serverless),
            LoggingMatcher(is_serverless=is_serverless),
            JvmAccessMatcher(session_state=session_state),
            RDDApiMatcher(session_state=session_state),
            SparkSqlContextMatcher(session_state=session_state),
            LoggingMatcher(session_state=session_state),
            UDFMatcher(session_state=session_state),
            CommandContextMatcher(session_state=session_state),
        ]
        if session_state.dbr_version and session_state.dbr_version < (14, 3):
            self._matchers.append(CatalogApiMatcher(session_state=session_state))

    def lint_tree(self, tree: Tree) -> Iterator[Advice]:
        for matcher in self._matchers:
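
To show the version gating above, a small sketch (not part of the change) of how `CatalogApiMatcher` only applies below DBR 14.3, mirroring the functional samples further down; as before, it assumes `lint(code)` accepts a source string:

```python
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter

snippet = 'spark.catalog.tableExists("table")'

# On DBR 14.3 spark.catalog is supported, so CatalogApiMatcher is never registered.
assert not list(SparkConnectLinter(CurrentSessionState(dbr_version=(14, 3))).lint(snippet))

# On DBR 13.3 the same call is flagged.
advices = list(SparkConnectLinter(CurrentSessionState(dbr_version=(13, 3))).lint(snippet))
assert [advice.code for advice in advices] == ['catalog-api-in-shared-clusters']
```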
22 changes: 11 additions & 11 deletions tests/unit/source_code/linters/test_spark_connect.py
@@ -1,12 +1,12 @@
from itertools import chain

from databricks.labs.ucx.source_code.base import Failure
from databricks.labs.ucx.source_code.base import Failure, CurrentSessionState
from databricks.labs.ucx.source_code.linters.python_ast import Tree
from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter


def test_jvm_access_match_shared():
    linter = SparkConnectLinter(is_serverless=False)
    linter = SparkConnectLinter(CurrentSessionState())
    code = """
spark.range(10).collect()
spark._jspark._jvm.com.my.custom.Name()
@@ -26,7 +26,7 @@ def test_jvm_access_match_shared():


def test_jvm_access_match_serverless():
    linter = SparkConnectLinter(is_serverless=True)
    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
    code = """
spark.range(10).collect()
spark._jspark._jvm.com.my.custom.Name()
@@ -47,7 +47,7 @@ def test_jvm_access_match_serverless():


def test_rdd_context_match_shared():
    linter = SparkConnectLinter(is_serverless=False)
    linter = SparkConnectLinter(CurrentSessionState())
    code = """
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -91,7 +91,7 @@ def test_rdd_context_match_shared():


def test_rdd_context_match_serverless():
    linter = SparkConnectLinter(is_serverless=True)
    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
    code = """
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = spark.createDataFrame(sc.emptyRDD(), schema)
@@ -133,7 +133,7 @@ def test_rdd_context_match_serverless():


def test_rdd_map_partitions():
    linter = SparkConnectLinter(is_serverless=False)
    linter = SparkConnectLinter(CurrentSessionState())
    code = """
df = spark.createDataFrame([])
df.rdd.mapPartitions(myUdf)
@@ -153,7 +153,7 @@ def test_rdd_map_partitions():


def test_conf_shared():
    linter = SparkConnectLinter(is_serverless=False)
    linter = SparkConnectLinter(CurrentSessionState())
    code = """df.sparkContext.getConf().get('spark.my.conf')"""
    assert [
        Failure(
@@ -168,7 +168,7 @@ def test_conf_shared():


def test_conf_serverless():
    linter = SparkConnectLinter(is_serverless=True)
    linter = SparkConnectLinter(CurrentSessionState(is_serverless=True))
    code = """sc._conf().get('spark.my.conf')"""
    expected = [
        Failure(
@@ -185,7 +185,7 @@ def test_conf_serverless():


def test_logging_shared():
    logging_matcher = LoggingMatcher(is_serverless=False)
    logging_matcher = LoggingMatcher(CurrentSessionState())
    code = """
sc.setLogLevel("INFO")
setLogLevel("WARN")
@@ -226,7 +226,7 @@ def test_logging_shared():


def test_logging_serverless():
    logging_matcher = LoggingMatcher(is_serverless=True)
    logging_matcher = LoggingMatcher(CurrentSessionState(is_serverless=True))
    code = """
sc.setLogLevel("INFO")
log4jLogger = sc._jvm.org.apache.log4j
@@ -255,7 +255,7 @@ def test_logging_serverless():


def test_valid_code():
    linter = SparkConnectLinter()
    linter = SparkConnectLinter(CurrentSessionState())
    code = """
df = spark.range(10)
df.collect()
@@ -0,0 +1,23 @@
# ucx[session-state] {"dbr_version": [13, 3]}
# ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters
spark.catalog.tableExists("table")
# ucx[catalog-api-in-shared-clusters:+1:0:+1:13] spark.catalog functions require DBR 14.3 LTS or above on UC Shared Clusters
spark.catalog.listDatabases()


def catalog():
    pass


catalog()


class Fatalog:
    def tableExists(self, x): ...
class Foo:
    def catalog(self):
        Fatalog()


x = Foo()
x.catalog.tableExists("...")
@@ -0,0 +1,10 @@
# ucx[session-state] {"dbr_version": [14, 3]}
spark.catalog.tableExists("table")
spark.catalog.listDatabases()


def catalog():
    pass


catalog()
@@ -0,0 +1,6 @@
# ucx[to-json-in-shared-clusters:+1:6:+1:80] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
print(dbutils.notebook.entry_point.getDbutils().notebook().getContext().toJson())
dbutils.notebook.entry_point.getDbutils().notebook().getContext().toSafeJson()
notebook = dbutils.notebook.entry_point.getDbutils().notebook()
# ucx[to-json-in-shared-clusters:+1:0:+1:30] toJson() is not available on UC Shared Clusters. Use toSafeJson() on DBR 13.3 LTS or above to get a subset of command context information.
notebook.getContext().toJson()
@@ -0,0 +1,53 @@
# ucx[session-state] {"dbr_version": [13, 3]}
from pyspark.sql.functions import udf, udtf, lit
import pandas as pd


@udf(returnType='int')
def slen(s):
    return len(s)


# ucx[python-udf-in-shared-clusters:+1:1:+1:37] Arrow UDFs require DBR 14.3 LTS or above on UC Shared Clusters
@udf(returnType='int', useArrow=True)
def arrow_slen(s):
    return len(s)


df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
df.select(slen("name"), arrow_slen("name")).show()

slen1 = udf(lambda s: len(s), returnType='int')
# ucx[python-udf-in-shared-clusters:+1:14:+1:68] Arrow UDFs require DBR 14.3 LTS or above on UC Shared Clusters
arrow_slen1 = udf(lambda s: len(s), returnType='int', useArrow=True)

df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))

df.select(slen1("name"), arrow_slen1("name")).show()

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))


def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    v = pdf.v
    return pdf.assign(v=v - v.mean())


# ucx[python-udf-in-shared-clusters:+1:0:+1:73] applyInPandas require DBR 14.3 LTS or above on UC Shared Clusters
df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()


class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)


# ucx[python-udf-in-shared-clusters:+1:13:+1:69] udtf require DBR 14.3 LTS or above on UC Shared Clusters
square_num = udtf(SquareNumbers, returnType="num: int, squared: int")
square_num(lit(1), lit(3)).show()

from pyspark.sql.types import IntegerType

# ucx[python-udf-in-shared-clusters:+1:0:+1:73] Cannot register Java UDF from Python code on UC Shared Clusters. Use a %scala cell to register the Scala UDF using spark.udf.register.
spark.udf.registerJavaFunction("func", "org.example.func", IntegerType())