Skip to content

Add "what" property for migration to scope down table migrations. #856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 30, 2024
Merged
14 changes: 8 additions & 6 deletions src/databricks/labs/ucx/hive_metastore/table_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from databricks.labs.ucx.framework.crawlers import SqlBackend
from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table
from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table, What

logger = logging.getLogger(__name__)


class TablesMigrate:

def __init__(
self,
tc: TablesCrawler,
Expand All @@ -34,23 +35,24 @@ def __init__(
self._tm = tm
self._seen_tables: dict[str, str] = {}

def migrate_tables(self, *, what: What | None = None):
    """Migrate crawled Hive metastore tables to Unity Catalog.

    :param what: optional migration category used to scope the run down to a
        single kind of table (e.g. ``What.EXTERNAL_SYNC``); when ``None``,
        every table returned by the mapping is migrated.
    """
    self._init_seen_tables()
    tables_to_migrate = self._tm.get_tables_to_migrate(self._tc)
    tasks = []
    for table in tables_to_migrate:
        # Only schedule tables matching the requested category (no filter == all).
        if what is None or table.src.what == what:
            tasks.append(partial(self._migrate_table, table.src, table.rule))
    Threads.strict("migrate tables", tasks)

def _migrate_table(self, src_table: Table, rule: Rule):
    """Migrate a single table or view, dispatching on its `What` category.

    Returns True when the table was migrated, was already upgraded, or is a
    kind we deliberately do not migrate.
    """
    if self._table_already_upgraded(rule.as_uc_table_key):
        logger.info(f"Table {src_table.key} already upgraded to {rule.as_uc_table_key}")
        return True
    if src_table.what == What.DBFS_ROOT_DELTA:
        return self._migrate_dbfs_root_table(src_table, rule)
    if src_table.what == What.EXTERNAL_SYNC:
        return self._migrate_external_table(src_table, rule)
    if src_table.what == What.VIEW:
        return self._migrate_view(src_table, rule)
    # DBFS_ROOT_NON_DELTA, EXTERNAL_NO_SYNC, DB_DATASET and UNKNOWN have no
    # migration strategy yet — log and report as handled.
    logger.info(f"Table {src_table.key} is not supported for migration")
    return True
Expand Down
27 changes: 27 additions & 0 deletions src/databricks/labs/ucx/hive_metastore/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial

from databricks.labs.blueprint.parallel import Threads
Expand All @@ -14,6 +15,16 @@
logger = logging.getLogger(__name__)


class What(Enum):
    """Category of a Hive metastore table, used to pick and scope a migration strategy."""

    EXTERNAL_SYNC = 1  # table in a format the SYNC statement supports
    EXTERNAL_NO_SYNC = 2  # table in a format SYNC does not support
    DBFS_ROOT_DELTA = 3  # Delta table stored under the DBFS root
    DBFS_ROOT_NON_DELTA = 4  # non-Delta table stored under the DBFS root
    VIEW = 5
    DB_DATASET = 6  # Databricks sample dataset (dbfs:/databricks-datasets)
    UNKNOWN = 7  # matched no other category


@dataclass
class Table:
catalog: str
Expand Down Expand Up @@ -96,6 +107,22 @@ def is_databricks_dataset(self) -> bool:
return True
return False

@property
def what(self) -> What:
    """Classify this table into a migration category.

    Precedence matters: Databricks datasets win over everything, then
    DBFS-root storage, then regular tables, then views.
    """
    if self.is_databricks_dataset:
        return What.DB_DATASET
    if self.is_dbfs_root:
        return What.DBFS_ROOT_DELTA if self.table_format == "DELTA" else What.DBFS_ROOT_NON_DELTA
    if self.kind == "TABLE":
        return What.EXTERNAL_SYNC if self.is_format_supported_for_sync else What.EXTERNAL_NO_SYNC
    return What.VIEW if self.kind == "VIEW" else What.UNKNOWN

def sql_migrate_external(self, target_table_key):
    """Render the SYNC TABLE statement that upgrades this external table to *target_table_key*."""
    target = escape_sql_identifier(target_table_key)
    source = escape_sql_identifier(self.key)
    return f"SYNC TABLE {target} FROM {source};"

Expand Down
6 changes: 6 additions & 0 deletions tests/integration/hive_metastore/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from databricks.sdk.retries import retried

from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore.tables import What

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,8 +39,13 @@ def test_describe_all_tables_in_databases(ws, sql_backend, inventory_schema, mak

assert len(all_tables) >= 5
assert all_tables[non_delta.full_name].table_format == "JSON"
assert all_tables[non_delta.full_name].what == What.DB_DATASET
assert all_tables[managed_table.full_name].object_type == "MANAGED"
assert all_tables[managed_table.full_name].what == What.DBFS_ROOT_DELTA
assert all_tables[tmp_table.full_name].object_type == "MANAGED"
assert all_tables[tmp_table.full_name].what == What.DBFS_ROOT_DELTA
assert all_tables[external_table.full_name].object_type == "EXTERNAL"
assert all_tables[external_table.full_name].what == What.EXTERNAL_NO_SYNC
assert all_tables[view.full_name].object_type == "VIEW"
assert all_tables[view.full_name].view_text == "SELECT 2+2 AS four"
assert all_tables[view.full_name].what == What.VIEW
72 changes: 72 additions & 0 deletions tests/unit/hive_metastore/test_table_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MigrationCount,
Table,
TablesCrawler,
What,
)

from ..framework.mocks import MockBackend
Expand Down Expand Up @@ -66,6 +67,25 @@ def test_migrate_dbfs_root_tables_should_produce_proper_queries():
)


def test_migrate_dbfs_root_tables_should_be_skipped_when_upgrading_external():
    """A DBFS-root Delta table must not be touched when the run is scoped to EXTERNAL_SYNC."""
    backend = MockBackend(fails_on_first={}, rows={})
    crawler = TablesCrawler(backend, "inventory_database")
    ws = MagicMock()
    mapping = create_autospec(TableMapping)
    src = Table("hive_metastore", "db1_src", "managed_dbfs", "MANAGED", "DELTA", "dbfs:/some_location")
    rule = Rule("workspace", "ucx_default", "db1_src", "db1_dst", "managed_dbfs", "managed_dbfs")
    mapping.get_tables_to_migrate.return_value = [TableToMigrate(src, rule)]

    migrator = TablesMigrate(crawler, ws, backend, mapping)
    migrator.migrate_tables(what=What.EXTERNAL_SYNC)

    # The only candidate is DBFS_ROOT_DELTA, so nothing should have been executed.
    assert not backend.queries


def test_migrate_external_tables_should_produce_proper_queries():
errors = {}
rows = {}
Expand All @@ -87,6 +107,58 @@ def test_migrate_external_tables_should_produce_proper_queries():
]


def test_migrate_already_upgraded_table_should_produce_no_queries():
    """A table whose UC twin carries `upgraded_from` must be skipped without SQL."""
    errors = {}
    rows = {}
    backend = MockBackend(fails_on_first=errors, rows=rows)
    table_crawler = TablesCrawler(backend, "inventory_database")
    client = create_autospec(WorkspaceClient)
    # Fake the UC side: one catalog/schema containing a table already marked
    # as upgraded from the source table below.
    client.catalogs.list.return_value = [CatalogInfo(name="cat1")]
    client.schemas.list.return_value = [
        SchemaInfo(catalog_name="cat1", name="test_schema1"),
    ]
    client.tables.list.return_value = [
        TableInfo(
            catalog_name="cat1",
            schema_name="schema1",
            name="dest1",
            full_name="cat1.schema1.dest1",
            properties={"upgraded_from": "hive_metastore.db1_src.external_src"},
        ),
    ]

    table_mapping = create_autospec(TableMapping)
    table_mapping.get_tables_to_migrate.return_value = [
        TableToMigrate(
            Table("hive_metastore", "db1_src", "external_src", "EXTERNAL", "DELTA"),
            Rule("workspace", "cat1", "db1_src", "schema1", "external_src", "dest1"),
        )
    ]
    table_migrate = TablesMigrate(table_crawler, client, backend, table_mapping)
    table_migrate.migrate_tables()

    # Already-upgraded tables are detected via the seen-tables scan, so no SQL runs.
    assert len(backend.queries) == 0


def test_migrate_unsupported_format_table_should_produce_no_queries():
    """Tables in a format with no migration strategy are skipped without running SQL."""
    backend = MockBackend(fails_on_first={}, rows={})
    crawler = TablesCrawler(backend, "inventory_database")
    ws = create_autospec(WorkspaceClient)
    mapping = create_autospec(TableMapping)
    src = Table("hive_metastore", "db1_src", "external_src", "EXTERNAL", "UNSUPPORTED_FORMAT")
    rule = Rule("workspace", "cat1", "db1_src", "schema1", "external_src", "dest1")
    mapping.get_tables_to_migrate.return_value = [TableToMigrate(src, rule)]

    migrator = TablesMigrate(crawler, ws, backend, mapping)
    migrator.migrate_tables()

    assert not backend.queries


def test_migrate_view_should_produce_proper_queries():
errors = {}
rows = {}
Expand Down
142 changes: 92 additions & 50 deletions tests/unit/hive_metastore/test_tables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler
from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler, What

from ..framework.mocks import MockBackend

Expand Down Expand Up @@ -136,52 +136,94 @@ def test_tables_returning_error_when_describing():
assert len(results) == 1


def test_is_dbfs_root():
    """`is_dbfs_root` covers plain dbfs:/ and /dbfs/ fuse-mount paths only."""
    assert Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename").is_dbfs_root
    assert Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename").is_dbfs_root
    # Mount points are external storage, not DBFS root.
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename").is_dbfs_root
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename").is_dbfs_root
    # Databricks sample datasets are classified separately.
    assert not Table(
        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"
    ).is_dbfs_root
    assert not Table(
        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"
    ).is_dbfs_root
    # Cloud URIs are never DBFS root.
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename").is_dbfs_root
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename").is_dbfs_root

def test_is_db_dataset():
    """Only the databricks-datasets path (dbfs:/ or /dbfs/ form) is a Databricks dataset."""
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="dbfs:/somelocation/tablename").is_databricks_dataset
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="/dbfs/somelocation/tablename").is_databricks_dataset
    assert not Table(
        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/mnt/somelocation/tablename"
    ).is_databricks_dataset
    assert not Table(
        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/mnt/somelocation/tablename"
    ).is_databricks_dataset
    # Both spellings of the sample-datasets root qualify.
    assert Table(
        "a", "b", "c", "MANAGED", "DELTA", location="dbfs:/databricks-datasets/somelocation/tablename"
    ).is_databricks_dataset
    assert Table(
        "a", "b", "c", "MANAGED", "DELTA", location="/dbfs/databricks-datasets/somelocation/tablename"
    ).is_databricks_dataset
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="s3:/somelocation/tablename").is_databricks_dataset
    assert not Table("a", "b", "c", "MANAGED", "DELTA", location="adls:/somelocation/tablename").is_databricks_dataset


def test_is_supported_for_sync():
    """DELTA/CSV/TEXT/ORC/JSON can be upgraded via SYNC; AVRO cannot."""
    assert Table(
        "a", "b", "c", "EXTERNAL", "DELTA", location="dbfs:/somelocation/tablename"
    ).is_format_supported_for_sync
    assert Table("a", "b", "c", "EXTERNAL", "CSV", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
    assert Table(
        "a", "b", "c", "EXTERNAL", "TEXT", location="dbfs:/somelocation/tablename"
    ).is_format_supported_for_sync
    assert Table("a", "b", "c", "EXTERNAL", "ORC", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
    assert Table(
        "a", "b", "c", "EXTERNAL", "JSON", location="dbfs:/somelocation/tablename"
    ).is_format_supported_for_sync
    # AVRO is the unsupported case.
    assert not (
        Table("a", "b", "c", "EXTERNAL", "AVRO", location="dbfs:/somelocation/tablename").is_format_supported_for_sync
    )
@pytest.mark.parametrize(
    'table,dbfs_root,what',
    [
        # Same fixture table for every case; only format and location vary.
        (Table("a", "b", "c", "MANAGED", fmt, location=loc), dbfs_root, what)
        for fmt, loc, dbfs_root, what in [
            ("DELTA", "dbfs:/somelocation/tablename", True, What.DBFS_ROOT_DELTA),
            ("PARQUET", "dbfs:/somelocation/tablename", True, What.DBFS_ROOT_NON_DELTA),
            ("DELTA", "/dbfs/somelocation/tablename", True, What.DBFS_ROOT_DELTA),
            ("DELTA", "dbfs:/mnt/somelocation/tablename", False, What.EXTERNAL_SYNC),
            ("DELTA", "/dbfs/mnt/somelocation/tablename", False, What.EXTERNAL_SYNC),
            ("DELTA", "dbfs:/databricks-datasets/somelocation/tablename", False, What.DB_DATASET),
            ("DELTA", "/dbfs/databricks-datasets/somelocation/tablename", False, What.DB_DATASET),
            ("DELTA", "s3:/somelocation/tablename", False, What.EXTERNAL_SYNC),
            ("DELTA", "adls:/somelocation/tablename", False, What.EXTERNAL_SYNC),
        ]
    ],
)
def test_is_dbfs_root(table, dbfs_root, what):
    """DBFS-root detection and the derived `what` category must agree per location."""
    assert table.is_dbfs_root == dbfs_root
    assert table.what == what


@pytest.mark.parametrize(
    'table,db_dataset',
    [
        # Managed DELTA everywhere; only the storage location varies.
        (Table("a", "b", "c", "MANAGED", "DELTA", location=loc), flag)
        for loc, flag in [
            ("dbfs:/somelocation/tablename", False),
            ("/dbfs/somelocation/tablename", False),
            ("dbfs:/mnt/somelocation/tablename", False),
            ("/dbfs/mnt/somelocation/tablename", False),
            ("dbfs:/databricks-datasets/somelocation/tablename", True),
            ("/dbfs/databricks-datasets/somelocation/tablename", True),
            ("s3:/somelocation/tablename", False),
            ("adls:/somelocation/tablename", False),
        ]
    ],
)
def test_is_db_dataset(table, db_dataset):
    """Only databricks-datasets paths classify as DB_DATASET; the flag and `what` agree."""
    assert table.is_databricks_dataset == db_dataset
    assert (table.what == What.DB_DATASET) == db_dataset


@pytest.mark.parametrize(
    'table,supported',
    [
        # Every SYNC-supported format plus AVRO, the unsupported one.
        (Table("a", "b", "c", "EXTERNAL", fmt, location="dbfs:/somelocation/tablename"), fmt != "AVRO")
        for fmt in ("DELTA", "CSV", "TEXT", "ORC", "JSON", "AVRO")
    ],
)
def test_is_supported_for_sync(table, supported):
    """`is_format_supported_for_sync` matches the expected per-format answer."""
    assert table.is_format_supported_for_sync == supported


@pytest.mark.parametrize(
    'table,what',
    [
        # One representative table per What category (views carry view_text, not a location).
        (Table("a", "b", "c", kind, fmt, **extra), what)
        for kind, fmt, extra, what in [
            ("EXTERNAL", "DELTA", {"location": "s3://external_location/table"}, What.EXTERNAL_SYNC),
            ("EXTERNAL", "UNSUPPORTED_FORMAT", {"location": "s3://external_location/table"}, What.EXTERNAL_NO_SYNC),
            ("MANAGED", "DELTA", {"location": "dbfs:/somelocation/tablename"}, What.DBFS_ROOT_DELTA),
            ("MANAGED", "PARQUET", {"location": "dbfs:/somelocation/tablename"}, What.DBFS_ROOT_NON_DELTA),
            ("VIEW", "VIEW", {"view_text": "select * from some_table"}, What.VIEW),
            ("MANAGED", "DELTA", {"location": "dbfs:/databricks-datasets/somelocation/tablename"}, What.DB_DATASET),
        ]
    ],
)
def test_table_what(table, what):
    """Each fixture lands in exactly the expected migration category."""
    assert table.what == what