Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip listing built-in catalogs to update table migration process #3464

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from databricks.labs.lsql.backends import SqlBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import DatabricksError, NotFound
from databricks.sdk.service.catalog import CatalogInfo
from databricks.sdk.service.catalog import CatalogInfo, CatalogInfoSecurableKind, SchemaInfo

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
Expand Down Expand Up @@ -79,6 +79,11 @@ class TableMigrationStatusRefresher(CrawlerBase[TableMigrationStatus]):
properties for the presence of the marker.
"""

_skip_catalog_securable_kinds = [
CatalogInfoSecurableKind.CATALOG_INTERNAL,
CatalogInfoSecurableKind.CATALOG_SYSTEM,
]

def __init__(self, ws: WorkspaceClient, sql_backend: SqlBackend, schema, tables_crawler: TablesCrawler):
super().__init__(sql_backend, "hive_metastore", schema, "migration_status", TableMigrationStatus)
self._ws = ws
Expand All @@ -90,6 +95,8 @@ def index(self, *, force_refresh: bool = False) -> TableMigrationIndex:
def get_seen_tables(self) -> dict[str, str]:
seen_tables: dict[str, str] = {}
for schema in self._iter_schemas():
if schema.catalog_name is None or schema.name is None:
continue
try:
# ws.tables.list returns Iterator[TableInfo], so we need to convert it to a list in order to catch the exception
tables = list(self._ws.tables.list(catalog_name=schema.catalog_name, schema_name=schema.name))
Expand Down Expand Up @@ -136,9 +143,7 @@ def _crawl(self) -> Iterable[TableMigrationStatus]:
src_schema = table.database.lower()
src_table = table.name.lower()
table_migration_status = TableMigrationStatus(
src_schema=src_schema,
src_table=src_table,
update_ts=str(timestamp),
src_schema=src_schema, src_table=src_table, update_ts=str(timestamp)
)
if table.key in reverse_seen and self.is_migrated(src_schema, src_table):
target_table = reverse_seen[table.key]
Expand All @@ -157,12 +162,17 @@ def _try_fetch(self) -> Iterable[TableMigrationStatus]:

def _iter_catalogs(self) -> Iterable[CatalogInfo]:
try:
yield from self._ws.catalogs.list()
for catalog in self._ws.catalogs.list():
if catalog.securable_kind in self._skip_catalog_securable_kinds:
continue
yield catalog
except DatabricksError as e:
logger.error("Cannot list catalogs", exc_info=e)

def _iter_schemas(self):
def _iter_schemas(self) -> Iterable[SchemaInfo]:
for catalog in self._iter_catalogs():
if catalog.name is None:
continue
try:
yield from self._ws.schemas.list(catalog_name=catalog.name)
except NotFound:
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/hive_metastore/test_table_migration_status.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Iterable
from unittest.mock import create_autospec

import pytest
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import BadRequest, DatabricksError, NotFound
from databricks.sdk.service.catalog import CatalogInfo, SchemaInfo
from databricks.sdk.service.catalog import CatalogInfoSecurableKind, CatalogInfo, SchemaInfo, TableInfo

from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationStatusRefresher
Expand Down Expand Up @@ -64,3 +65,57 @@ def test_table_migration_status_refresher_get_seen_tables_handles_errors_on_tabl
ws.schemas.list.assert_called_once()
ws.tables.list.assert_called_once()
tables_crawler.snapshot.assert_not_called()


@pytest.mark.parametrize(
"securable_kind",
[
CatalogInfoSecurableKind.CATALOG_INTERNAL,
CatalogInfoSecurableKind.CATALOG_SYSTEM,
],
)
def test_table_migration_status_refresher_get_seen_tables_skips_builtin_catalog(
mock_backend, securable_kind: CatalogInfoSecurableKind
) -> None:
ws = create_autospec(WorkspaceClient)
ws.catalogs.list.return_value = [
CatalogInfo(name="test"),
CatalogInfo(name="system", securable_kind=securable_kind),
]

def schemas_list(catalog_name: str) -> Iterable[SchemaInfo]:
schemas = [
SchemaInfo(catalog_name="test", name="test"),
SchemaInfo(catalog_name="system", name="access"),
]
for schema in schemas:
if schema.catalog_name == catalog_name:
yield schema

def tables_list(catalog_name: str, schema_name: str) -> Iterable[TableInfo]:
tables = [
TableInfo(
full_name="test.test.test",
catalog_name="test",
schema_name="test",
name="test",
properties={"upgraded_from": "test"},
),
TableInfo(catalog_name="system", schema_name="access", name="audit"),
]
for table in tables:
if table.catalog_name == catalog_name and table.schema_name == schema_name:
yield table

ws.schemas.list.side_effect = schemas_list
ws.tables.list.side_effect = tables_list
tables_crawler = create_autospec(TablesCrawler)
refresher = TableMigrationStatusRefresher(ws, mock_backend, "test", tables_crawler)

seen_tables = refresher.get_seen_tables()

assert seen_tables == {"test.test.test": "test"}
ws.catalogs.list.assert_called_once()
ws.schemas.list.assert_called_once_with(catalog_name="test") # System is NOT called
ws.tables.list.assert_called()
tables_crawler.snapshot.assert_not_called()
Loading