diff --git a/src/databricks/labs/lsql/dashboards.py b/src/databricks/labs/lsql/dashboards.py index e8db354d..ef1b7b88 100644 --- a/src/databricks/labs/lsql/dashboards.py +++ b/src/databricks/labs/lsql/dashboards.py @@ -379,23 +379,6 @@ def _get_abstract_syntax_tree(self) -> sqlglot.Expression | None: logger.warning(f"Parsing {self.content}: {e}") return None - def _get_query(self) -> str: - _, query = self.metadata.handler.split() - if self.query_transformer is None: - return query - syntax_tree = self._get_abstract_syntax_tree() - if syntax_tree is None: - return query - query_transformed = syntax_tree.transform(self.query_transformer).sql( - dialect=self._DIALECT, - # A transformer requires to (re)define how to output SQL - normalize=True, # normalize identifiers to lowercase - pretty=True, # format the produced SQL string - normalize_functions="upper", # normalize function names to uppercase - max_text_width=80, # wrap text at 80 characters - ) - return query_transformed - def _find_fields(self) -> list[Field]: """Find the fields in a query. @@ -416,6 +399,52 @@ def _find_fields(self) -> list[Field]: fields.append(field) return fields + def replace_database( + self, + catalog: str | None = None, + database: str | None = None, + *, + catalog_to_replace: str | None = None, + database_to_replace: str | None = None, + ) -> "QueryTile": + """Replace the database in the query. + + Parameters : + catalog : str + The value to replace the catalog with + database : str + The value to replace the database with + catalog_to_replace : str | None (default: None) + The catalog to replace, if None, all catalogs are replaced + database_to_replace : str | None (default: None) + The database to replace, if None, all databases are replaced + """ + + def replace_catalog_and_database_in_query(node: sqlglot.Expression) -> sqlglot.Expression: + if isinstance(node, sqlglot.exp.Table): + if node.args.get("catalog") is not None and ( + catalog_to_replace is None or getattr(node.args.get("catalog"), "this", "") == catalog_to_replace + ): + node.args["catalog"].set("this", catalog) + if node.args.get("db") is not None and ( + database_to_replace is None or getattr(node.args.get("db"), "this", "") == database_to_replace + ): + node.args["db"].set("this", database) + return node + + syntax_tree = self._get_abstract_syntax_tree() + if syntax_tree is None: + return dataclasses.replace(self, _content=self.content) + content_transformed = syntax_tree.transform(replace_catalog_and_database_in_query).sql( + dialect=self._DIALECT, + # A transformer requires to (re)define how to output SQL + normalize=True, # normalize identifiers to lowercase + pretty=True, # format the produced SQL string + normalize_functions="upper", # normalize function names to uppercase + max_text_width=80, # wrap text at 80 characters + ) + return dataclasses.replace(self, _content=content_transformed) + def infer_spec_type(self) -> type[WidgetSpec] | None: """Infer the spec type from the query.""" if self.metadata.widget_type != WidgetType.AUTO: @@ -429,8 +458,7 @@ def infer_spec_type(self) -> type[WidgetSpec] | None: def get_datasets(self) -> Iterable[Dataset]: """Get the dataset belonging to the query.""" - query = self._get_query() - dataset = Dataset(name=self.metadata.id, display_name=self.metadata.id, query=query) + dataset = Dataset(name=self.metadata.id, display_name=self.metadata.id, query=self.content) yield dataset def _merge_nested_dictionaries(self, left: dict, right: dict) -> dict: @@ -607,8 +635,6 @@ class DashboardMetadata: """The metadata defining a lakeview dashboard""" display_name: str # The dashboard display name - # A sqlglot transformer applied on the queries before creating the tiles. If None, no transformation is applied - query_transformer: Callable[[sqlglot.Expression], sqlglot.Expression] | None = None _tiles: list[Tile] = dataclasses.field(default_factory=list) # The dashboard tiles @@ -627,8 +653,6 @@ def tiles(self) -> list[Tile]: tiles_with_order.append((order, tile)) tiles, position = [], Position(0, 0, 0, 0) # Position of first tile for _, tile in sorted(tiles_with_order, key=lambda el: (el[0], el[1].metadata.id)): - if isinstance(tile, QueryTile): - tile.query_transformer = self.query_transformer tile.place_after(position) tiles.append(tile) position = tile.position @@ -668,40 +692,15 @@ def _merge_tile_metadatas(cls, left: list[TileMetadata], right: list[TileMetadat metadatas.append(metadata) return metadatas - def replace_database( - self, - catalog: str | None = None, - database: str | None = None, - *, - catalog_to_replace: str | None = None, - database_to_replace: str | None = None, - ) -> "DashboardMetadata": - """Replace the database in the queries. - - Parameters : - catalog : str - The value to replace the catalog with - database : str - The value to replace the database with - catalog_to_replace : str | None (default: None) - The catalog to replace, if None, all catalogs are replaced - database_to_replace : str | None (default: None) - The database to replace, if None, all databases are replaced - """ - - def replace_catalog_and_database_in_query(node: sqlglot.Expression) -> sqlglot.Expression: - if isinstance(node, sqlglot.exp.Table): - if node.args.get("catalog") is not None and ( - catalog_to_replace is None or getattr(node.args.get("catalog"), "this", "") == catalog_to_replace - ): - node.args["catalog"].set("this", catalog) - if node.args.get("db") is not None and ( - database_to_replace is None or getattr(node.args.get("db"), "this", "") == database_to_replace - ): - node.args["db"].set("this", database) - return node - - return dataclasses.replace(self, query_transformer=replace_catalog_and_database_in_query) + def replace_database(self, *args, **kwargs) -> "DashboardMetadata": + """Wrapper around :method:QueryTile.replace_database""" + tiles: list[Tile] = [] + for tile in self.tiles: + if isinstance(tile, QueryTile): + tiles.append(tile.replace_database(*args, **kwargs)) + else: + tiles.append(tile) + return dataclasses.replace(self, _tiles=tiles) def get_datasets(self) -> list[Dataset]: """Get the datasets for the dashboard.""" diff --git a/tests/unit/test_dashboards.py b/tests/unit/test_dashboards.py index c20efcce..9c3c4ae0 100644 --- a/tests/unit/test_dashboards.py +++ b/tests/unit/test_dashboards.py @@ -741,7 +741,7 @@ def test_query_tile_keeps_original_query(tmp_path): query_path.write_text(query) tile_metadata = TileMetadata.from_path(query_path) - query_tile = QueryTile(tile_metadata) + query_tile = QueryTile.from_tile_metadata(tile_metadata) dataset = next(query_tile.get_datasets())