diff --git a/src/databricks/labs/lsql/dashboards.py b/src/databricks/labs/lsql/dashboards.py
index a7a18a2f..51c25c02 100644
--- a/src/databricks/labs/lsql/dashboards.py
+++ b/src/databricks/labs/lsql/dashboards.py
@@ -1,6 +1,7 @@
+import dataclasses
 import json
 from pathlib import Path
-from typing import ClassVar, Protocol, runtime_checkable
+from typing import TypeVar
 
 import sqlglot
 import yaml
@@ -24,41 +25,56 @@
     Widget,
 )
 
-
-@runtime_checkable
-class _DataclassInstance(Protocol):
-    __dataclass_fields__: ClassVar[dict]
+T = TypeVar("T")
 
 
 class Dashboards:
     def __init__(self, ws: WorkspaceClient):
         self._ws = ws
 
-    def get_dashboard(self, dashboard_path: str):
+    def get_dashboard(self, dashboard_path: str) -> Dashboard:
         with self._ws.workspace.download(dashboard_path, format=ExportFormat.SOURCE) as f:
             raw = f.read().decode("utf-8")
             as_dict = json.loads(raw)
             return Dashboard.from_dict(as_dict)
 
-    def save_to_folder(self, dashboard_path: str, local_path: Path):
+    def save_to_folder(self, dashboard: Dashboard, local_path: Path) -> Dashboard:
         local_path.mkdir(parents=True, exist_ok=True)
-        dashboard = self.get_dashboard(dashboard_path)
-        better_names = {}
+        dashboard = self._with_better_names(dashboard)
         for dataset in dashboard.datasets:
-            name = dataset.display_name
-            better_names[dataset.name] = name
-            query_path = local_path / f"{name}.sql"
-            sql_query = dataset.query
-            self._format_sql_file(sql_query, query_path)
-        lvdash_yml = local_path / "lvdash.yml"
-        with lvdash_yml.open("w") as f:
-            first_page = dashboard.pages[0]
-            self._replace_names(first_page, better_names)
-            page = first_page.as_dict()
-            yaml.safe_dump(page, f)
-        assert True
+            query = self._format_query(dataset.query)
+            with (local_path / f"{dataset.name}.sql").open("w") as f:
+                f.write(query)
+        for page in dashboard.pages:
+            with (local_path / f"{page.name}.yml").open("w") as f:
+                yaml.safe_dump(page.as_dict(), f)
+        return dashboard
+
+    @staticmethod
+    def _format_query(query: str) -> str:
+        try:
+            parsed_query = sqlglot.parse(query)
+        except sqlglot.ParseError:
+            return query
+        statements = []
+        for statement in parsed_query:
+            if statement is None:
+                continue
+            # see https://sqlglot.com/sqlglot/generator.html#Generator
+            statements.append(
+                statement.sql(
+                    dialect="databricks",
+                    normalize=True,  # normalize identifiers to lowercase
+                    pretty=True,  # format the produced SQL string
+                    normalize_functions="upper",  # normalize function names to uppercase
+                    max_text_width=80,  # wrap text at 80 characters
+                )
+            )
+        formatted_query = ";\n".join(statements)
+        return formatted_query
 
-    def create_dashboard(self, dashboard_folder: Path) -> Dashboard:
+    @staticmethod
+    def create_dashboard(dashboard_folder: Path) -> Dashboard:
         """Create a dashboard from code, i.e. configuration and queries."""
         datasets, layouts = [], []
         for query_path in dashboard_folder.glob("*.sql"):
@@ -99,42 +115,44 @@ def deploy_dashboard(
         )
         return dashboard
 
-    def _format_sql_file(self, sql_query, query_path):
-        with query_path.open("w") as f:
-            try:
-                for statement in sqlglot.parse(sql_query):
-                    # see https://sqlglot.com/sqlglot/generator.html#Generator
-                    pretty = statement.sql(
-                        dialect="databricks",
-                        normalize=True,  # normalize identifiers to lowercase
-                        pretty=True,  # format the produced SQL string
-                        normalize_functions="upper",  # normalize function names to uppercase
-                        max_text_width=80,  # wrap text at 120 characters
-                    )
-                    f.write(f"{pretty};\n")
-            except sqlglot.ParseError:
-                f.write(sql_query)
+    def _with_better_names(self, dashboard: Dashboard) -> Dashboard:
+        """Replace names with human-readable names."""
+        better_names = {}
+        for dataset in dashboard.datasets:
+            if dataset.display_name is not None:
+                better_names[dataset.name] = dataset.display_name
+        for page in dashboard.pages:
+            if page.display_name is not None:
+                better_names[page.name] = page.display_name
+        return self._replace_names(dashboard, better_names)
 
-    def _replace_names(self, node: _DataclassInstance, better_names: dict[str, str]):
-        # walk evely dataclass instance recursively and replace names
-        if isinstance(node, _DataclassInstance):
-            for field in node.__dataclass_fields__.values():
+    def _replace_names(self, node: T, better_names: dict[str, str]) -> T:
+        # walk every dataclass instance recursively and replace names
+        if dataclasses.is_dataclass(node):
+            for field in dataclasses.fields(node):
                 value = getattr(node, field.name)
                 if isinstance(value, list):
                     setattr(node, field.name, [self._replace_names(item, better_names) for item in value])
-                elif isinstance(value, _DataclassInstance):
+                elif dataclasses.is_dataclass(value):
                     setattr(node, field.name, self._replace_names(value, better_names))
-        if isinstance(node, Query):
+        if isinstance(node, Dataset):
+            node.name = better_names.get(node.name, node.name)
+        elif isinstance(node, Page):
+            node.name = better_names.get(node.name, node.name)
+        elif isinstance(node, Query):
             node.dataset_name = better_names.get(node.dataset_name, node.dataset_name)
         elif isinstance(node, NamedQuery) and node.query:
             # 'dashboards/01eeb077e38c17e6ba3511036985960c/datasets/01eeb081882017f6a116991d124d3068_...'
             if node.name.startswith("dashboards/"):
                 parts = [node.query.dataset_name]
-                for field in node.query.fields:
-                    parts.append(field.name)
+                for query_field in node.query.fields:
+                    parts.append(query_field.name)
                 new_name = "_".join(parts)
                 better_names[node.name] = new_name
             node.name = better_names.get(node.name, node.name)
         elif isinstance(node, ControlFieldEncoding):
             node.query_name = better_names.get(node.query_name, node.query_name)
+        elif isinstance(node, Widget):
+            if node.spec is not None:
+                node.name = node.spec.as_dict().get("widgetType", node.name)
         return node
diff --git a/tests/integration/queries/counter.sql b/tests/integration/dashboards/one_counter/counter.sql
similarity index 100%
rename from tests/integration/queries/counter.sql
rename to tests/integration/dashboards/one_counter/counter.sql
diff --git a/tests/integration/test_dashboards.py b/tests/integration/test_dashboards.py
index 5cd3fecd..97e2048d 100644
--- a/tests/integration/test_dashboards.py
+++ b/tests/integration/test_dashboards.py
@@ -1,11 +1,10 @@
 import json
-from dataclasses import fields, is_dataclass
 from pathlib import Path
 
 import pytest
 
 from databricks.labs.lsql.dashboards import Dashboards
-from databricks.labs.lsql.lakeview.model import CounterSpec, Dashboard
+from databricks.labs.lsql.lakeview.model import Dashboard
 
 
 @pytest.fixture
@@ -20,68 +19,23 @@ def dashboard_id(ws, make_random):
     ws.lakeview.trash(dashboard.dashboard_id)
 
 
-def test_load_dashboard(ws):
-    dashboard = Dashboards(ws)
-    src = "/Workspace/Users/serge.smertin@databricks.com/Trivial Dashboard.lvdash.json"
-    dst = Path(__file__).parent / "sample"
-    dashboard.save_to_folder(src, dst)
-
-
-def test_dashboard_creates_one_dataset_per_query(ws):
-    queries = Path(__file__).parent / "queries"
-    dashboard = Dashboards(ws).create_dashboard(queries)
-    assert len(dashboard.datasets) == len([query for query in queries.glob("*.sql")])
-
-
-def test_dashboard_creates_one_counter_widget_per_query(ws):
-    queries = Path(__file__).parent / "queries"
-    dashboard = Dashboards(ws).create_dashboard(queries)
-
-    counter_widgets = []
-    for page in dashboard.pages:
-        for layout in page.layout:
-            if isinstance(layout.widget.spec, CounterSpec):
-                counter_widgets.append(layout.widget)
-
-    assert len(counter_widgets) == len([query for query in queries.glob("*.sql")])
-
-
-def replace_recursively(dataklass, replace_fields):
-    for field in fields(dataklass):
-        value = getattr(dataklass, field.name)
-        if is_dataclass(value):
-            new_value = replace_recursively(value, replace_fields)
-        elif isinstance(value, list):
-            new_value = [replace_recursively(v, replace_fields) for v in value]
-        elif isinstance(value, tuple):
-            new_value = (replace_recursively(v, replace_fields) for v in value)
-        else:
-            new_value = replace_fields.get(field.name, value)
-        setattr(dataklass, field.name, new_value)
-    return dataklass
-
-
-def test_dashboard_deploys_dashboard(ws, dashboard_id):
-    queries = Path(__file__).parent / "queries"
-    dashboard_client = Dashboards(ws)
-    lakeview_dashboard = dashboard_client.create_dashboard(queries)
-
-    dashboard = dashboard_client.deploy_dashboard(lakeview_dashboard, dashboard_id=dashboard_id)
-    deployed_lakeview_dashboard = dashboard_client.get_dashboard(dashboard.path)
-
-    replace_name = {"name": "test", "dataset_name": "test"}  # Dynamically created names
-    lakeview_dashboard_wo_name = replace_recursively(lakeview_dashboard, replace_name)
-    deployed_lakeview_dashboard_wo_name = replace_recursively(deployed_lakeview_dashboard, replace_name)
-
-    assert lakeview_dashboard_wo_name.as_dict() == deployed_lakeview_dashboard_wo_name.as_dict()
-
-
 def test_dashboards_deploys_exported_dashboard_definition(ws, dashboard_id):
     dashboard_file = Path(__file__).parent / "dashboards" / "dashboard.json"
     with dashboard_file.open("r") as f:
         lakeview_dashboard = Dashboard.from_dict(json.load(f))
 
-    dashboard_client = Dashboards(ws)
-    dashboard = dashboard_client.deploy_dashboard(lakeview_dashboard, dashboard_id=dashboard_id)
+    dashboards = Dashboards(ws)
+    dashboard = dashboards.deploy_dashboard(lakeview_dashboard, dashboard_id=dashboard_id)
 
     assert ws.lakeview.get(dashboard.dashboard_id)
+
+
+def test_dashboard_deploys_dashboard_the_same_as_created_dashboard(ws, dashboard_id):
+    queries = Path(__file__).parent / "dashboards" / "one_counter"
+    dashboards = Dashboards(ws)
+    dashboard = dashboards.create_dashboard(queries)
+
+    sdk_dashboard = dashboards.deploy_dashboard(dashboard, dashboard_id=dashboard_id)
+    new_dashboard = dashboards.get_dashboard(sdk_dashboard.path)
+
+    assert dashboards._with_better_names(dashboard).as_dict() == dashboards._with_better_names(new_dashboard).as_dict()
diff --git a/tests/unit/queries/counter.sql b/tests/unit/queries/counter.sql
new file mode 100644
index 00000000..b2ab78e2
--- /dev/null
+++ b/tests/unit/queries/counter.sql
@@ -0,0 +1 @@
+SELECT 6217 AS count
\ No newline at end of file
diff --git a/tests/unit/test_dashboards.py b/tests/unit/test_dashboards.py
new file mode 100644
index 00000000..d2116fab
--- /dev/null
+++ b/tests/unit/test_dashboards.py
@@ -0,0 +1,170 @@
+from pathlib import Path
+from unittest.mock import create_autospec
+
+import pytest
+from databricks.sdk import WorkspaceClient
+
+from databricks.labs.lsql.dashboards import Dashboards
+from databricks.labs.lsql.lakeview import (
+    CounterEncodingMap,
+    CounterSpec,
+    Dashboard,
+    Dataset,
+    Layout,
+    NamedQuery,
+    Page,
+    Position,
+    Query,
+    Widget,
+)
+
+
+def test_dashboards_saves_sql_files_to_folder(tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    queries = Path(__file__).parent / "queries"
+    dashboard = Dashboards(ws).create_dashboard(queries)
+
+    Dashboards(ws).save_to_folder(dashboard, tmp_path)
+
+    assert len(list(tmp_path.glob("*.sql"))) == len(dashboard.datasets)
+    ws.assert_not_called()
+
+
+def test_dashboards_saves_yml_files_to_folder(tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    queries = Path(__file__).parent / "queries"
+    dashboard = Dashboards(ws).create_dashboard(queries)
+
+    Dashboards(ws).save_to_folder(dashboard, tmp_path)
+
+    assert len(list(tmp_path.glob("*.yml"))) == len(dashboard.pages)
+    ws.assert_not_called()
+
+
+def test_dashboards_creates_one_dataset_per_query():
+    ws = create_autospec(WorkspaceClient)
+    queries = Path(__file__).parent / "queries"
+    dashboard = Dashboards(ws).create_dashboard(queries)
+    assert len(dashboard.datasets) == len([query for query in queries.glob("*.sql")])
+
+
+def test_dashboards_creates_one_counter_widget_per_query():
+    ws = create_autospec(WorkspaceClient)
+    queries = Path(__file__).parent / "queries"
+    dashboard = Dashboards(ws).create_dashboard(queries)
+
+    counter_widgets = []
+    for page in dashboard.pages:
+        for layout in page.layout:
+            if isinstance(layout.widget.spec, CounterSpec):
+                counter_widgets.append(layout.widget)
+
+    assert len(counter_widgets) == len([query for query in queries.glob("*.sql")])
+
+
+def test_dashboards_deploy_raises_value_error_with_missing_display_name_and_dashboard_id():
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+    dashboard = Dashboard([], [])
+    with pytest.raises(ValueError):
+        dashboards.deploy_dashboard(dashboard)
+    ws.assert_not_called()
+
+
+def test_dashboards_deploy_raises_value_error_with_both_display_name_and_dashboard_id():
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+    dashboard = Dashboard([], [])
+    with pytest.raises(ValueError):
+        dashboards.deploy_dashboard(dashboard, display_name="test", dashboard_id="test")
+    ws.assert_not_called()
+
+
+def test_dashboards_deploy_calls_create_with_display_name():
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+    dashboard = Dashboard([], [])
+    dashboards.deploy_dashboard(dashboard, display_name="test")
+
+    ws.lakeview.create.assert_called_once()
+    ws.lakeview.update.assert_not_called()
+
+
+def test_dashboards_deploy_calls_update_with_dashboard_id():
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+    dashboard = Dashboard([], [])
+    dashboards.deploy_dashboard(dashboard, dashboard_id="test")
+
+    ws.lakeview.create.assert_not_called()
+    ws.lakeview.update.assert_called_once()
+
+
+def test_dashboards_save_to_folder_replaces_dataset_names_with_display_names(tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+
+    datasets = [Dataset(name="ugly", query="SELECT 1", display_name="pretty")]
+    dashboard = dashboards.save_to_folder(Dashboard(datasets, []), tmp_path)
+
+    assert all(dataset.name == "pretty" for dataset in dashboard.datasets)
+    ws.assert_not_called()
+
+
+def test_dashboards_save_to_folder_replaces_page_names_with_display_names(tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+
+    pages = [Page(name="ugly", layout=[], display_name="pretty")]
+    dashboard = dashboards.save_to_folder(Dashboard([], pages), tmp_path)
+
+    assert all(page.name == "pretty" for page in dashboard.pages)
+    ws.assert_not_called()
+
+
+@pytest.fixture
+def ugly_dashboard() -> Dashboard:
+    datasets = [Dataset(name="ugly", query="SELECT 1", display_name="pretty")]
+
+    query = Query(dataset_name="ugly", fields=[])
+    named_query = NamedQuery(name="main_query", query=query)
+    counter_spec = CounterSpec(CounterEncodingMap())
+    widget = Widget(name="ugly", queries=[named_query], spec=counter_spec)
+    position = Position(x=0, y=0, width=1, height=1)
+    layout = Layout(widget=widget, position=position)
+    pages = [Page(name="ugly", layout=[layout], display_name="pretty")]
+
+    dashboard = Dashboard(datasets, pages)
+    return dashboard
+
+
+def test_dashboards_save_to_folder_replaces_query_name_with_dataset_name(ugly_dashboard, tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+
+    dashboard = dashboards.save_to_folder(ugly_dashboard, tmp_path)
+
+    queries = []
+    for page in dashboard.pages:
+        for layout in page.layout:
+            for named_query in layout.widget.queries:
+                queries.append(named_query.query)
+
+    assert all(query.dataset_name == "pretty" for query in queries)
+    ws.assert_not_called()
+
+
+def test_dashboards_save_to_folder_replaces_counter_names(ugly_dashboard, tmp_path):
+    ws = create_autospec(WorkspaceClient)
+    dashboards = Dashboards(ws)
+
+    dashboard = dashboards.save_to_folder(ugly_dashboard, tmp_path)
+
+    counters = []
+    for page in dashboard.pages:
+        for layout in page.layout:
+            if isinstance(layout.widget.spec, CounterSpec):
+                counters.append(layout.widget)
+
+    assert all(counter.name == "counter" for counter in counters)
+    ws.assert_not_called()
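
For reference, a minimal usage sketch of the reworked Dashboards API that this change introduces. It assumes an already configured Databricks authentication profile; the folder path and display name below are illustrative placeholders, not values taken from this diff:

    from pathlib import Path

    from databricks.sdk import WorkspaceClient

    from databricks.labs.lsql.dashboards import Dashboards

    ws = WorkspaceClient()  # assumes default SDK auth, e.g. DATABRICKS_HOST / DATABRICKS_TOKEN
    dashboards = Dashboards(ws)

    # Build a Dashboard object from a folder of queries: one dataset and one
    # counter widget per *.sql file, as exercised by the unit tests above.
    dashboard = dashboards.create_dashboard(Path("dashboards/one_counter"))

    # Deploy it; pass either display_name (creates) or dashboard_id (updates),
    # but not both, otherwise deploy_dashboard raises ValueError.
    sdk_dashboard = dashboards.deploy_dashboard(dashboard, display_name="One counter")

    # Round-trip the deployed dashboard back to disk: one .sql file per dataset
    # and one .yml file per page, with human-readable names where available.
    lakeview_dashboard = dashboards.get_dashboard(sdk_dashboard.path)
    dashboards.save_to_folder(lakeview_dashboard, Path("exported"))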