Skip to content

Commit

Permalink
Move replace_database to QueryTile
Browse files Browse the repository at this point in the history
  • Loading branch information
JCZuurmond committed Jul 10, 2024
1 parent b97a63c commit cc4ce27
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 58 deletions.
113 changes: 56 additions & 57 deletions src/databricks/labs/lsql/dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,23 +379,6 @@ def _get_abstract_syntax_tree(self) -> sqlglot.Expression | None:
logger.warning(f"Parsing {self.content}: {e}")
return None

def _get_query(self) -> str:
_, query = self.metadata.handler.split()
if self.query_transformer is None:
return query
syntax_tree = self._get_abstract_syntax_tree()
if syntax_tree is None:
return query
query_transformed = syntax_tree.transform(self.query_transformer).sql(
dialect=self._DIALECT,
# A transformer requires to (re)define how to output SQL
normalize=True, # normalize identifiers to lowercase
pretty=True, # format the produced SQL string
normalize_functions="upper", # normalize function names to uppercase
max_text_width=80, # wrap text at 80 characters
)
return query_transformed

def _find_fields(self) -> list[Field]:
"""Find the fields in a query.
Expand All @@ -416,6 +399,52 @@ def _find_fields(self) -> list[Field]:
fields.append(field)
return fields

def replace_database(
self,
catalog: str | None = None,
database: str | None = None,
*,
catalog_to_replace: str | None = None,
database_to_replace: str | None = None,
) -> "QueryTile":
"""Replace the database in the query.
Parameters :
catalog : str
The value to replace the catalog with
database : str
The value to replace the database with
catalog_to_replace : str | None (default: None)
The catalog to replace, if None, all catalogs are replaced
database_to_replace : str | None (default: None)
The database to replace, if None, all databases are replaced
"""

def replace_catalog_and_database_in_query(node: sqlglot.Expression) -> sqlglot.Expression:
if isinstance(node, sqlglot.exp.Table):
if node.args.get("catalog") is not None and (
catalog_to_replace is None or getattr(node.args.get("catalog"), "this", "") == catalog_to_replace
):
node.args["catalog"].set("this", catalog)
if node.args.get("db") is not None and (
database_to_replace is None or getattr(node.args.get("db"), "this", "") == database_to_replace
):
node.args["db"].set("this", database)
return node

syntax_tree = self._get_abstract_syntax_tree()
if syntax_tree is None:
return dataclasses.replace(self, _content=self.content)
content_transformed = syntax_tree.transform(replace_catalog_and_database_in_query).sql(
dialect=self._DIALECT,
# A transformer requires to (re)define how to output SQL
normalize=True, # normalize identifiers to lowercase
pretty=True, # format the produced SQL string
normalize_functions="upper", # normalize function names to uppercase
max_text_width=80, # wrap text at 80 characters
)
return dataclasses.replace(self, _content=content_transformed)

def infer_spec_type(self) -> type[WidgetSpec] | None:
"""Infer the spec type from the query."""
if self.metadata.widget_type != WidgetType.AUTO:
Expand All @@ -429,8 +458,7 @@ def infer_spec_type(self) -> type[WidgetSpec] | None:

def get_datasets(self) -> Iterable[Dataset]:
"""Get the dataset belonging to the query."""
query = self._get_query()
dataset = Dataset(name=self.metadata.id, display_name=self.metadata.id, query=query)
dataset = Dataset(name=self.metadata.id, display_name=self.metadata.id, query=self.content)
yield dataset

def _merge_nested_dictionaries(self, left: dict, right: dict) -> dict:
Expand Down Expand Up @@ -607,8 +635,6 @@ class DashboardMetadata:
"""The metadata defining a lakeview dashboard"""

display_name: str # The dashboard display name
# A sqlglot transformer applied on the queries before creating the tiles. If None, no transformation is applied
query_transformer: Callable[[sqlglot.Expression], sqlglot.Expression] | None = None

_tiles: list[Tile] = dataclasses.field(default_factory=list) # The dashboard tiles

Expand All @@ -627,8 +653,6 @@ def tiles(self) -> list[Tile]:
tiles_with_order.append((order, tile))
tiles, position = [], Position(0, 0, 0, 0) # Position of first tile
for _, tile in sorted(tiles_with_order, key=lambda el: (el[0], el[1].metadata.id)):
if isinstance(tile, QueryTile):
tile.query_transformer = self.query_transformer
tile.place_after(position)
tiles.append(tile)
position = tile.position
Expand Down Expand Up @@ -668,40 +692,15 @@ def _merge_tile_metadatas(cls, left: list[TileMetadata], right: list[TileMetadat
metadatas.append(metadata)
return metadatas

def replace_database(
self,
catalog: str | None = None,
database: str | None = None,
*,
catalog_to_replace: str | None = None,
database_to_replace: str | None = None,
) -> "DashboardMetadata":
"""Replace the database in the queries.
Parameters :
catalog : str
The value to replace the catalog with
database : str
The value to replace the database with
catalog_to_replace : str | None (default: None)
The catalog to replace, if None, all catalogs are replaced
database_to_replace : str | None (default: None)
The database to replace, if None, all databases are replaced
"""

def replace_catalog_and_database_in_query(node: sqlglot.Expression) -> sqlglot.Expression:
if isinstance(node, sqlglot.exp.Table):
if node.args.get("catalog") is not None and (
catalog_to_replace is None or getattr(node.args.get("catalog"), "this", "") == catalog_to_replace
):
node.args["catalog"].set("this", catalog)
if node.args.get("db") is not None and (
database_to_replace is None or getattr(node.args.get("db"), "this", "") == database_to_replace
):
node.args["db"].set("this", database)
return node

return dataclasses.replace(self, query_transformer=replace_catalog_and_database_in_query)
def replace_database(self, *args, **kwargs) -> "DashboardMetadata":
"""Wrapper around :method:QueryTile.replace_database"""
tiles: list[Tile] = []
for tile in self.tiles:
if isinstance(tile, QueryTile):
tiles.append(tile.replace_database(*args, **kwargs))
else:
tiles.append(tile)
return dataclasses.replace(self, _tiles=tiles)

def get_datasets(self) -> list[Dataset]:
"""Get the datasets for the dashboard."""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def test_query_tile_keeps_original_query(tmp_path):
query_path.write_text(query)

tile_metadata = TileMetadata.from_path(query_path)
query_tile = QueryTile(tile_metadata)
query_tile = QueryTile.from_tile_metadata(tile_metadata)

dataset = next(query_tile.get_datasets())

Expand Down

0 comments on commit cc4ce27

Please sign in to comment.