Extract command codes and unify the checks for spark_conf, cluster_policy, init_scripts #855

Merged · 9 commits · Jan 30, 2024
93 changes: 15 additions & 78 deletions src/databricks/labs/ucx/assessment/clusters.py
@@ -3,18 +3,9 @@
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.compute import ClusterDetails, ClusterSource, Policy
from databricks.sdk.service.compute import ClusterSource

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
INCOMPATIBLE_SPARK_CONFIG_KEYS,
_azure_sp_conf_in_init_scripts,
_azure_sp_conf_present_check,
_get_init_script_data,
logger,
spark_version_compatibility,
)
from databricks.labs.ucx.assessment.crawlers import _check_cluster_failures, logger
Collaborator

don't import logger, initialize one at the top of the module.

Contributor Author

fixed
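
A minimal sketch of the suggested pattern, i.e. each module initializing its own logger at the top instead of importing one from crawlers.py:

import logging

# module-level logger, named after the module itself
logger = logging.getLogger(__name__)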

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend


@@ -27,72 +18,7 @@ class ClusterInfo:
creator: str | None = None


class ClustersMixin:
_ws: WorkspaceClient

def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
try:
return self._ws.cluster_policies.get(policy_id)
except NotFound:
logger.warning(f"The cluster policy was deleted: {policy_id}")
return None

def _check_spark_conf(self, cluster, failures):
for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
if k in cluster.spark_conf:
failures.append(f"unsupported config: {k}")
for value in cluster.spark_conf.values():
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
failures.append(f"using DBFS mount in configuration: {value}")
# Checking if Azure cluster config is present in spark config
if _azure_sp_conf_present_check(cluster.spark_conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

def _check_cluster_policy(self, cluster, failures):
policy = self._safe_get_cluster_policy(cluster.policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

def _check_init_scripts(self, cluster, failures):
for init_script_info in cluster.init_scripts:
init_script_data = _get_init_script_data(self._ws, init_script_info)
if not init_script_data:
continue
if not _azure_sp_conf_in_init_scripts(init_script_data):
continue
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

def _check_cluster_failures(self, cluster: ClusterDetails):
failures = []
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
support_status = spark_version_compatibility(cluster.spark_version)
if support_status != "supported":
failures.append(f"not supported DBR: {cluster.spark_version}")
if cluster.spark_conf is not None:
self._check_spark_conf(cluster, failures)
# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id:
self._check_cluster_policy(cluster, failures)
if cluster.init_scripts:
self._check_init_scripts(cluster, failures)
cluster_info.failures = json.dumps(failures)
if len(failures) > 0:
cluster_info.success = 0
return cluster_info


class ClustersCrawler(CrawlerBase[ClusterInfo], ClustersMixin):
class ClustersCrawler(CrawlerBase[ClusterInfo]):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "clusters", ClusterInfo)
self._ws = ws
@@ -110,7 +36,18 @@ def _assess_clusters(self, all_clusters):
f"Cluster {cluster.cluster_id} have Unknown creator, it means that the original creator "
f"has been deleted and should be re-created"
)
yield self._check_cluster_failures(cluster)
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
failures = _check_cluster_failures(self._ws, cluster, "cluster")
if len(failures) > 0:
cluster_info.success = 0
cluster_info.failures = json.dumps(failures)
yield cluster_info

def snapshot(self) -> Iterable[ClusterInfo]:
return self._snapshot(self._try_fetch, self._crawl)
73 changes: 73 additions & 0 deletions src/databricks/labs/ucx/assessment/crawlers.py
@@ -1,8 +1,12 @@
import base64
import json
import logging
import re

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service import compute
from databricks.sdk.service.compute import ClusterDetails, Policy
Collaborator

line 9 is redundant because you've imported the whole compute package. use one style of imports, don't mix

Contributor Author

fixed
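
A minimal sketch of the two consistent import styles the reviewer means (the point is to pick one and not mix them; the helper name is illustrative):

# style 1: import the package and qualify names through it
from databricks.sdk.service import compute

def _example(cluster: compute.ClusterDetails) -> None: ...

# style 2: import the needed names directly
from databricks.sdk.service.compute import ClusterDetails, Policy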


logger = logging.getLogger(__name__)

@@ -92,3 +96,72 @@
if (10, 0) <= version < (11, 3):
return "kinda works"
return "supported"


def _check_spark_conf(conf: dict[str, str], source) -> list[str]:
Collaborator

these functions are private (starting with _). you cannot export private methods from a module.

Contributor Author

Moved those functions back to the cluster mixin. I made check_cluster_failures and check_spark_conf public functions now, since the former is also used by jobs.py and the latter by pipelines.py.
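
A minimal sketch of the naming convention settled on in this thread (the __all__ list is illustrative):

# Helpers that other modules rely on get public names; module-internal
# helpers keep the leading underscore and stay out of __all__.
__all__ = ["check_cluster_failures", "check_spark_conf"]

def check_spark_conf(conf: dict[str, str], source: str) -> list[str]:
    ...  # used by pipelines.py as well, hence the public name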

failures = []
for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
if k in conf:
failures.append(f"unsupported config: {k}")
for value in conf.values():
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
failures.append(f"using DBFS mount in configuration: {value}")

# Checking if Azure cluster config is present in spark config
if _azure_sp_conf_present_check(conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


def _safe_get_cluster_policy(ws: WorkspaceClient, policy_id: str) -> Policy | None:
Collaborator

So, a logical question: do we want diamond-shaped dependencies or not? For example, where should _safe_get_cluster_policy live, in crawlers.py or in clusters.py? And what about _check_cluster_failures?

flowchart TD
    assessment --> crawlers
    crawlers --> clusters
    crawlers --> jobs
    crawlers --> init_scripts
    crawlers --> pipelines
    crawlers --> azure_spns
    clusters --> runtime.py
    jobs --> runtime.py
    init_scripts --> runtime.py
    pipelines --> runtime.py
    azure_spns --> runtime.py

or something like this:

flowchart TD
    assessment --> crawlers
    azure_spns -->|mixin| clusters
    crawlers --> clusters
    clusters -->|mixin| jobs
    crawlers --> jobs
    azure_spns -->|mixin| init_scripts
    crawlers --> init_scripts
    jobs -->|mixin| pipelines
    crawlers --> pipelines
    crawlers --> azure_spns
    clusters --> runtime.py
    jobs --> runtime.py
    init_scripts --> runtime.py
    pipelines --> runtime.py
    azure_spns --> runtime.py

Contributor Author (@qziyuan, Jan 29, 2024)

I'd prefer the first dependency structure; it's clearer to me.
As I understand it, clusters.py crawls and checks all all-purpose clusters, jobs.py crawls and checks all jobs (so far it only checks job clusters, but we may need to check the job code in the future), and pipelines.py crawls and checks all DLT pipelines (right now it only checks the pipeline config, but it should check the pipeline cluster as well).
It's clearer for them to use the shared _check_cluster_failures function to check the spark conf, init scripts, and cluster policy, instead of letting jobs and pipelines inherit _check_cluster_failures from clusters.
Some logic may stay in the domain-grouped modules, because it is not commonly shared across crawlers:

  • checking the all-purpose cluster access mode (we don't have this logic yet, but may add it in the future). The cluster may need to be put into shared mode, which has more limitations, but this check does not apply to job clusters.
  • checking the job code (we don't have this logic yet, but may add it in the future)

try:
return ws.cluster_policies.get(policy_id)
except NotFound:
logger.warning(f"The cluster policy was deleted: {policy_id}")
return None


def _check_cluster_policy(ws: WorkspaceClient, cluster, source):
Collaborator

let's add type annotations to top-level members if we really don't want mixins

Contributor Author

Fixed; type annotations were added even though we went back to the mixin.
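
A minimal sketch of the annotated signature being asked for (the parameter and return types here are assumptions inferred from how the helper is used in this diff):

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import ClusterDetails

def _check_cluster_policy(ws: WorkspaceClient, cluster: ClusterDetails, source: str) -> list[str]:
    ...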

failures = []
policy = _safe_get_cluster_policy(ws, cluster.policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


def _check_cluster_init_script(ws: WorkspaceClient, init_scripts, source):
Collaborator

if we're not doing mixins, then why is this function defined so far away from _get_init_script_data?

Contributor Author

Moved them to the mixin and placed these two functions next to each other.

failures = []
for init_script_info in init_scripts:
init_script_data = _get_init_script_data(ws, init_script_info)
failures.extend(_check_init_script(init_script_data, source))
return failures


def _check_init_script(init_script_data, source):
failures = []
if not init_script_data:
return failures
if _azure_sp_conf_in_init_scripts(init_script_data):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


def _check_cluster_failures(ws: WorkspaceClient, cluster: ClusterDetails | compute.ClusterSpec, source):
Collaborator

don't do union types, they already caused a lot of bugs.

convert cluster spec into cluster details instead:

https://github.com/databrickslabs/ucx/blob/main/src/databricks/labs/ucx/assessment/jobs.py#L88-L89

Contributor Author

Fixed; it was missed when merging with #845.
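
A hedged sketch of the conversion the reviewer links to, so callers pass a single ClusterDetails rather than a union type (assuming cluster_config and ws come from the surrounding crawler, as in jobs.py):

from databricks.sdk.service.compute import ClusterDetails

# caller side: normalize a job's ClusterSpec into ClusterDetails before checking it
cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
failures = _check_cluster_failures(ws, cluster_details, "Job cluster")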

failures = []

support_status = spark_version_compatibility(cluster.spark_version)
if support_status != "supported":
failures.append(f"not supported DBR: {cluster.spark_version}")
if cluster.spark_conf is not None:
failures.extend(_check_spark_conf(cluster.spark_conf, source))
# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id:
failures.extend(_check_cluster_policy(ws, cluster, source))
if cluster.init_scripts:
failures.extend(_check_cluster_init_script(ws, cluster.init_scripts, source))

return failures
11 changes: 3 additions & 8 deletions src/databricks/labs/ucx/assessment/init_scripts.py
@@ -5,11 +5,7 @@

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_in_init_scripts,
logger,
)
from databricks.labs.ucx.assessment.crawlers import _check_init_script, logger
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend


@@ -52,9 +48,8 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
global_init_script = base64.b64decode(script.script).decode("utf-8")
if not global_init_script:
continue
if _azure_sp_conf_in_init_scripts(global_init_script):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} global init script.")
global_init_script_info.failures = json.dumps(failures)
failures.extend(_check_init_script(global_init_script, "global init script"))
global_init_script_info.failures = json.dumps(failures)

if len(failures) > 0:
global_init_script_info.success = 0
12 changes: 4 additions & 8 deletions src/databricks/labs/ucx/assessment/jobs.py
@@ -3,11 +3,9 @@
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import ClusterDetails
from databricks.sdk.service.jobs import BaseJob

from databricks.labs.ucx.assessment.clusters import ClustersMixin
from databricks.labs.ucx.assessment.crawlers import logger
from databricks.labs.ucx.assessment.crawlers import _check_cluster_failures, logger
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend


@@ -20,7 +18,7 @@ class JobInfo:
creator: str | None = None


class JobsMixin(ClustersMixin):
class JobsMixin:
Collaborator

For consistency, shouldn't it be named CheckJobsMixin?

@staticmethod
def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
for j in all_jobs:
@@ -85,10 +83,8 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[
job_id = job.job_id
if not job_id:
continue
cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
cluster_failures = self._check_cluster_failures(cluster_details)
for failure in json.loads(cluster_failures.failures):
job_assessment[job_id].add(failure)
cluster_failures = _check_cluster_failures(self._ws, cluster_config, "Job cluster")
job_assessment[job_id].update(cluster_failures)

# TODO: next person looking at this - rewrite, as this code makes no sense
for job_key in job_details.keys(): # pylint: disable=consider-using-dict-items,consider-iterating-dictionary
9 changes: 2 additions & 7 deletions src/databricks/labs/ucx/assessment/pipelines.py
@@ -4,11 +4,7 @@

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_present_check,
logger,
)
from databricks.labs.ucx.assessment.crawlers import _check_spark_conf, logger
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend


@@ -50,8 +46,7 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
assert pipeline_response.spec is not None
pipeline_config = pipeline_response.spec.configuration
if pipeline_config:
if _azure_sp_conf_present_check(pipeline_config):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")
failures.extend(_check_spark_conf(pipeline_config, "pipeline"))

pipeline_info.failures = json.dumps(failures)
if len(failures) > 0:
33 changes: 33 additions & 0 deletions tests/unit/assessment/clusters/job-source-cluster.json
@@ -0,0 +1,33 @@
[
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"cluster_source": "JOB",
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334422",
"cluster_name": "Single User Cluster Name",
"policy_id": "single-user-with-spn",
"spark_version": "9.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
},
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334411",
"cluster_name": "Single User Cluster Name",
"policy_id": "azure-oauth",
"spark_version": "13.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
}
]
22 changes: 22 additions & 0 deletions tests/unit/assessment/test_clusters.py
@@ -243,3 +243,25 @@ def test_cluster_with_multiple_failures():
failures = json.loads(result_set[0].failures)
assert 'unsupported config: spark.databricks.passthrough.enabled' in failures
assert 'not supported DBR: 9.3.x-cpu-ml-scala2.12' in failures


def test_cluster_with_job_source():
ws = workspace_client_mock(clusters="job-source-cluster.json")
crawler = ClustersCrawler(ws, MockBackend(), "ucx")
result_set = list(crawler.snapshot())

assert len(result_set) == 1
assert result_set[0].cluster_id == "0123-190044-1122334411"


def test_try_fetch():
ws = workspace_client_mock(clusters="assortment-conf.json")
mockBackend = MagicMock()
mockBackend.fetch.return_value = [("000", 1, "123")]
crawler = ClustersCrawler(ws, mockBackend, "ucx")
result_set = list(crawler.snapshot())

assert len(result_set) == 1
assert result_set[0].cluster_id == "000"
assert result_set[0].success == 1
assert result_set[0].failures == "123"