Skip to content

Commit

Permalink
Replaced JVM dependency with new API call.
Browse files Browse the repository at this point in the history
  • Loading branch information
FastLee committed Dec 13, 2024
1 parent 2bb401f commit e651bcd
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 35 deletions.
13 changes: 11 additions & 2 deletions src/databricks/labs/ucx/aws/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,17 @@ def get_roles_to_migrate(self) -> list[AWSCredentialCandidate]:
Identify the roles that need to be migrated to UC from the UC compatible roles list.
"""
external_locations = self._locations.snapshot()
logger.info(f"Found {len(external_locations)} external locations")
compatible_roles = self.load_uc_compatible_roles()
roles: dict[str, AWSCredentialCandidate] = {}
for external_location in external_locations:
path = PurePath(external_location.location)
for role in compatible_roles:
if not (path.match(role.resource_path) or path.match(role.resource_path + "/*")):
if not (
PurePath(role.resource_path) in path.parents
or path.match(role.resource_path)
or path.match(role.resource_path + "/*")
):
continue
if role.role_arn not in roles:
roles[role.role_arn] = AWSCredentialCandidate(
Expand Down Expand Up @@ -374,11 +379,15 @@ def create_uber_principal(self, prompts: Prompts):
logger.error(f"Failed to assign instance profile to cluster policy {iam_role_name}")
self._aws_resources.delete_instance_profile(iam_role_name, iam_role_name)

def _clean_location_name(self, location: str) -> str:
# Remove leading s3:// s3a:// and trailing /
return location.replace("s3://", "").replace("s3a://", "").replace("/", "_").replace(":", "_").replace(".", "_")

def _generate_role_name(self, single_role: bool, role_name: str, location: str) -> str:
    """Derive the name of the IAM role to create.

    In single-role mode the base name is suffixed with the current metastore
    id; otherwise it is suffixed with a sanitized form of the storage location.
    """
    if not single_role:
        return f"{role_name}_{self._clean_location_name(location)}"
    metastore_id = self._ws.metastores.current().as_dict()["metastore_id"]
    return f"{role_name}_{metastore_id}"

def delete_uc_role(self, role_name: str):
    """Delete the named IAM role via the AWS resources wrapper."""
    self._aws_resources.delete_role(role_name)
8 changes: 4 additions & 4 deletions src/databricks/labs/ucx/aws/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self._principal_acl = principal_acl
# When HMS federation is enabled, the fallback bit is set for all the
# locations which are created by UCX.
self._enable_fallback_mode = enable_hms_federation
self._enable_hms_federation = enable_hms_federation

def run(self) -> None:
"""
Expand All @@ -38,7 +38,7 @@ def run(self) -> None:
Create external location for the path using the credential identified
"""
credential_dict = self._get_existing_credentials_dict()
external_locations = self._external_locations.snapshot()
external_locations = list(self._external_locations.snapshot())
existing_external_locations = self._ws.external_locations.list()
existing_paths = []
for external_location in existing_external_locations:
Expand All @@ -56,7 +56,7 @@ def run(self) -> None:
path,
credential_dict[role_arn],
skip_validation=True,
fallback=self._enable_fallback_mode,
fallback=self._enable_hms_federation,
)
self._principal_acl.apply_location_acl()

Expand Down Expand Up @@ -91,7 +91,7 @@ def _identify_missing_external_locations(
path = role.resource_path
if path.endswith("/*"):
path = path[:-2]
if new_path.match(path + "/*") or new_path.match(path):
if PurePath(path) in new_path.parents or new_path.match(path + "/*") or new_path.match(path):
matching_role = role.role_arn
continue
if matching_role:
Expand Down
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/azure/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
self._resource_permissions = resource_permissions
self._azurerm = azurerm
self._principal_acl = principal_acl
self._enable_fallback_mode = enable_hms_federation
self._enable_hms_federation = enable_hms_federation

def _app_id_credential_name_mapping(self) -> tuple[dict[str, str], dict[str, str]]:
# list all storage credentials.
Expand Down Expand Up @@ -128,7 +128,7 @@ def _create_external_location_helper(
comment=comment,
read_only=read_only,
skip_validation=skip_validation,
fallback=self._enable_fallback_mode,
fallback=self._enable_hms_federation,
)
return url
except InvalidParameterValue as invalid:
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/labs/ucx/contexts/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,6 @@ def mounts_crawler(self) -> MountsCrawler:
self.sql_backend,
self.workspace_client,
self.inventory_database,
self.config.enable_hms_federation,
)

@cached_property
Expand All @@ -394,6 +393,7 @@ def external_locations(self) -> ExternalLocations:
self.inventory_database,
self.tables_crawler,
self.mounts_crawler,
self.config.enable_hms_federation,
)

@cached_property
Expand Down
42 changes: 29 additions & 13 deletions src/databricks/labs/ucx/hive_metastore/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _parse_location(cls, location: str | None) -> list[str]:
parse_result = cls._parse_url(location.rstrip("/"))
if not parse_result:
return []
parts = [parse_result.scheme, parse_result.netloc]
parts = [parse_result.scheme.replace("s3a", "s3"), parse_result.netloc]
for part in parse_result.path.split("/"):
if not part:
continue # remove empty strings
Expand Down Expand Up @@ -154,17 +154,36 @@ def __init__(
schema: str,
tables_crawler: TablesCrawler,
mounts_crawler: 'MountsCrawler',
enable_hms_federation: bool = False,
):
super().__init__(sql_backend, "hive_metastore", schema, "external_locations", ExternalLocation)
self._ws = ws
self._tables_crawler = tables_crawler
self._mounts_crawler = mounts_crawler
self._enable_hms_federation = enable_hms_federation

@cached_property
def _mounts_snapshot(self) -> list['Mount']:
    """All crawled mounts, ordered with the longest mount names first.

    Longest-name-first ordering lets callers match the most specific mount
    prefix before any shorter parent prefix. Cached for the object's lifetime.
    """
    def _specificity(mount):
        # Sort key: name length first, then the name itself as a tie-breaker.
        return (len(mount.name), mount.name)

    return sorted(self._mounts_crawler.snapshot(), key=_specificity, reverse=True)

def get_dbfs_root(self) -> ExternalLocation | None:
    """
    Get the root location of the DBFS
    Returns:
        Cloud storage root location for dbfs, or None when it cannot be resolved
    """
    logger.debug("Retrieving DBFS root location")
    response = self._ws.api_client.do("GET", "/api/2.0/dbfs/resolve-path", query={"path": "dbfs:/"})
    if not isinstance(response, dict):
        return None
    resolved_path = response.get("resolved_path")
    if not resolved_path:
        return None
    # Normalize the s3a scheme to s3 so locations compare consistently elsewhere.
    normalized = re.sub(r"^s3a:/", r"s3:/", resolved_path)
    return ExternalLocation(normalized, 0)

def _external_locations(self) -> Iterable[ExternalLocation]:
trie = LocationTrie()
for table in self._tables_crawler.snapshot():
Expand Down Expand Up @@ -194,10 +213,6 @@ def _external_locations(self) -> Iterable[ExternalLocation]:
external_locations.append(external_location)
continue
queue.extend(curr.children.values())
if self._mounts_snapshot:
root_dbfs = self._mounts_snapshot[-1]
if root_dbfs.name == '/':
external_locations.append(ExternalLocation(root_dbfs.source, 0))
return sorted(external_locations, key=lambda _: _.location)

def _resolve_location(self, table: Table) -> Table:
Expand Down Expand Up @@ -264,6 +279,14 @@ def _try_fetch(self) -> Iterable[ExternalLocation]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield ExternalLocation(*row)

def snapshot(self, *, force_refresh: bool = False) -> list[ExternalLocation]:
    """Return the crawled external locations, appending the DBFS root
    location when HMS federation is enabled and it can be resolved."""
    locations = list(super().snapshot(force_refresh=force_refresh))
    if not self._enable_hms_federation:
        return locations
    dbfs_root = self.get_dbfs_root()
    if dbfs_root:
        locations.append(dbfs_root)
    return locations

@staticmethod
def _get_ext_location_definitions(missing_locations: list[ExternalLocation]) -> list:
tf_script = []
Expand Down Expand Up @@ -363,11 +386,9 @@ def __init__(
sql_backend: SqlBackend,
ws: WorkspaceClient,
inventory_database: str,
enable_hms_federation: bool = False,
):
super().__init__(sql_backend, "hive_metastore", inventory_database, "mounts", Mount)
self._dbutils = ws.dbutils
self._enable_hms_federation = enable_hms_federation

@staticmethod
def _deduplicate_mounts(mounts: list) -> list:
Expand Down Expand Up @@ -396,6 +417,7 @@ def _jvm(self):
return None

def _resolve_dbfs_root(self) -> Mount | None:
# TODO: Consider deprecating this method and rely on the new API call
# pylint: disable=broad-exception-caught,too-many-try-statements
try:
jvm = self._jvm
Expand All @@ -419,12 +441,6 @@ def _crawl(self) -> Iterable[Mount]:
try:
for mount_point, source, _ in self._dbutils.fs.mounts():
mounts.append(Mount(mount_point, source))
if self._enable_hms_federation:
root_mount = self._resolve_dbfs_root()
if root_mount:
# filter out DatabricksRoot, otherwise ExternalLocations.resolve_mount() won't work
mounts = list(filter(lambda _: _.source != 'DatabricksRoot', mounts))
mounts.append(root_mount)
except Exception as error: # pylint: disable=broad-except
if "com.databricks.backend.daemon.dbutils.DBUtilsCore.mounts() is not whitelisted" in str(error):
logger.warning(
Expand Down
26 changes: 14 additions & 12 deletions tests/unit/hive_metastore/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,20 +725,22 @@ def my_side_effect(path, **_):


def test_resolve_dbfs_root_in_hms_federation():
    """DBFS root is resolved via the REST API, not the mounts crawler,
    when HMS federation is enabled."""
    sql_backend = MockBackend()
    ws = create_autospec(WorkspaceClient)
    # The resolve-path API reports the cloud path backing dbfs:/.
    ws.api_client.do.return_value = {"resolved_path": "s3:/foo/bar"}

    tables_crawler = create_autospec(TablesCrawler)
    tables_crawler.snapshot.return_value = [
        table_factory(["s3://test_location/test1/table1", ""]),
    ]
    mounts_crawler = create_autospec(MountsCrawler)
    external_locations = ExternalLocations(
        ws, sql_backend, "test", tables_crawler, mounts_crawler, enable_hms_federation=True
    )
    results = external_locations.snapshot()
    # The JVM/mounts-based resolution path must not be used anymore.
    mounts_crawler.snapshot.assert_not_called()
    assert results == [ExternalLocation("s3://test_location/test1", 1), ExternalLocation("s3:/foo/bar", 0)]
    assert len(results) == 2


def test_mount_listing_misplaced_flat_file():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def test_create_aws_uber_principal_calls_dbutils_fs_mounts(ws) -> None:
)
prompts = MockPrompts({})
create_uber_principal(ws, prompts, ctx=ctx)
ws.dbutils.fs.mounts.assert_called_once()
ws.dbutils.fs.mounts.assert_not_called()


def test_migrate_locations_raises_value_error_for_unsupported_cloud_provider(ws) -> None:
Expand Down

0 comments on commit e651bcd

Please sign in to comment.