Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql import column, ColumnElement, literal_column, quoted_name, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label
from sqlalchemy.sql.selectable import Alias, TableClause
Expand Down Expand Up @@ -1403,13 +1403,19 @@ def get_sqla_table(self) -> TableClause:
# project.dataset.table format
if self.catalog and self.database.db_engine_spec.supports_cross_catalog_queries:
# SQLAlchemy doesn't have built-in catalog support for TableClause,
# so we need to construct the full identifier manually
# so we need to construct the full identifier manually with proper quoting
catalog_quoted = self.quote_identifier(self.catalog)
table_quoted = self.quote_identifier(self.table_name)
Comment on lines +1407 to +1408

This comment was marked as resolved.


if self.schema:
full_name = f"{self.catalog}.{self.schema}.{self.table_name}"
schema_quoted = self.quote_identifier(self.schema)
full_name = f"{catalog_quoted}.{schema_quoted}.{table_quoted}"
else:
full_name = f"{self.catalog}.{self.table_name}"
full_name = f"{catalog_quoted}.{table_quoted}"
Comment on lines +1407 to +1414

This comment was marked as resolved.


return table(full_name)
# Use quoted_name with quote=False to prevent SQLAlchemy from re-quoting
# the already-quoted identifier components
return table(quoted_name(full_name, quote=False))

if self.schema:
return table(self.table_name, schema=self.schema)
Expand Down
154 changes: 150 additions & 4 deletions tests/unit_tests/connectors/sqla/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def test_fetch_metadata_empty_comment_field_handling(mocker: MockerFixture) -> N
"test_table",
"test_project",
"test_dataset",
"test_project.test_dataset.test_table",
'"test_project"."test_dataset"."test_table"',
None,
),
# Database supports cross-catalog queries, catalog only (no schema)
Expand All @@ -625,7 +625,7 @@ def test_fetch_metadata_empty_comment_field_handling(mocker: MockerFixture) -> N
"test_table",
"test_project",
None,
"test_project.test_table",
'"test_project"."test_table"',
None,
),
# Database supports cross-catalog queries, schema only (no catalog)
Expand Down Expand Up @@ -675,12 +675,14 @@ def test_get_sqla_table_with_catalog(
expected_name: str,
expected_schema: str | None,
) -> None:
"""Test that get_sqla_table handles catalog inclusion correctly based on
database cross-catalog support
"""
Test that `get_sqla_table` handles catalog inclusion correctly.
"""
# Mock database with specified cross-catalog support
database = mocker.MagicMock()
database.db_engine_spec.supports_cross_catalog_queries = supports_cross_catalog
# Provide a simple quote_identifier
database.quote_identifier = lambda x: f'"{x}"'

# Create table with specified parameters
table = SqlaTable(
Expand All @@ -696,3 +698,147 @@ def test_get_sqla_table_with_catalog(
# Verify expected table name and schema
assert sqla_table.name == expected_name
assert sqla_table.schema == expected_schema


@pytest.mark.parametrize(
    "table_name, catalog, schema, expected_in_sql, not_expected_in_sql",
    [
        # Mixed-case identifiers containing hyphens
        (
            "My-Table",
            "My-DB",
            "My-Schema",
            '"My-DB"."My-Schema"."My-Table"',
            '"My-DB.My-Schema.My-Table"',  # whole path quoted as one: wrong
        ),
        # All-uppercase identifiers
        (
            "ORDERS",
            "PROD_DB",
            "SALES",
            '"PROD_DB"."SALES"."ORDERS"',
            '"PROD_DB.SALES.ORDERS"',  # whole path quoted as one: wrong
        ),
        # Mixed-case identifiers containing spaces
        (
            "My Table",
            "My DB",
            "My Schema",
            '"My DB"."My Schema"."My Table"',
            '"My DB.My Schema.My Table"',  # whole path quoted as one: wrong
        ),
    ],
)
def test_get_sqla_table_quoting_for_cross_catalog(
    mocker: MockerFixture,
    table_name: str,
    catalog: str | None,
    schema: str | None,
    expected_in_sql: str,
    not_expected_in_sql: str,
) -> None:
    """
    Check that `get_sqla_table` quotes catalog, schema and table individually.

    Each case builds a `SqlaTable` on a mocked database that reports
    cross-catalog support, compiles a `SELECT` against the returned table
    with a PostgreSQL dialect, and inspects the rendered SQL:

    - `expected_in_sql` (every component quoted on its own) must be present;
    - `not_expected_in_sql` (the dotted path wrapped in a single pair of
      quotes) must be absent.
    """
    from sqlalchemy import create_engine, select

    # The engine is only used for its dialect (quoting rules and compilation);
    # the test never opens a connection through it.
    pg_engine = create_engine("postgresql://user:pass@host/db")
    identifier_preparer = pg_engine.dialect.identifier_preparer

    # Stand-in database: cross-catalog queries enabled, and identifier quoting
    # delegated to the real PostgreSQL preparer.
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.supports_cross_catalog_queries = True
    mock_database.quote_identifier = identifier_preparer.quote

    dataset = SqlaTable(
        table_name=table_name,
        database=mock_database,
        schema=schema,
        catalog=catalog,
    )

    # Render a SELECT over the table returned by the method under test.
    statement = select(dataset.get_sqla_table())
    compiled = str(
        statement.compile(pg_engine, compile_kwargs={"literal_binds": True})
    )

    # Every part of the path must be quoted separately ...
    assert expected_in_sql in compiled, f"Expected {expected_in_sql} in SQL: {compiled}"
    # ... and the full dotted path must never appear as one quoted identifier.
    assert not_expected_in_sql not in compiled, (
        f"Should not have {not_expected_in_sql} in SQL: {compiled}"
    )


def test_get_sqla_table_without_cross_catalog_ignores_catalog(
    mocker: MockerFixture,
) -> None:
    """
    Check that the catalog is dropped when cross-catalog queries are unsupported.

    A `SqlaTable` is created with table, schema and catalog all set, on a
    mocked database whose engine spec reports no cross-catalog support. The
    SQL compiled from `get_sqla_table()` must mention the schema and the
    table, and must not mention the catalog.
    """
    from sqlalchemy import create_engine, select

    # PostgreSQL engine, used only for identifier quoting and SQL compilation.
    pg_engine = create_engine("postgresql://user:pass@localhost/db")
    identifier_preparer = pg_engine.dialect.identifier_preparer

    # Stand-in database with cross-catalog queries switched off.
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.supports_cross_catalog_queries = False
    mock_database.quote_identifier = identifier_preparer.quote

    # The catalog is supplied on purpose; it is expected to be ignored.
    dataset = SqlaTable(
        table_name="my_table",
        database=mock_database,
        schema="my_schema",
        catalog="my_catalog",
    )

    # Render a SELECT over the table returned by the method under test.
    statement = select(dataset.get_sqla_table())
    compiled = str(
        statement.compile(pg_engine, compile_kwargs={"literal_binds": True})
    )

    # Only schema.table should be referenced, never catalog.schema.table.
    assert "my_schema" in compiled
    assert "my_table" in compiled
    assert "my_catalog" not in compiled


def test_quoted_name_prevents_double_quoting(mocker: MockerFixture) -> None:
    """
    Check that the pre-quoted cross-catalog name is not quoted a second time.

    With cross-catalog support enabled, `get_sqla_table()` is called for
    uppercase catalog, schema and table names. The compiled SQL must contain
    the three components quoted one by one, and must not contain the whole
    dotted path wrapped in a single pair of quotes.
    """
    from sqlalchemy import create_engine, select

    # PostgreSQL engine, used only for identifier quoting and SQL compilation.
    pg_engine = create_engine("postgresql://user:pass@host/db")
    identifier_preparer = pg_engine.dialect.identifier_preparer

    # Stand-in database with cross-catalog queries enabled.
    mock_database = mocker.MagicMock()
    mock_database.db_engine_spec.supports_cross_catalog_queries = True
    mock_database.quote_identifier = identifier_preparer.quote

    # Uppercase names are used so that the dialect is forced to quote them.
    dataset = SqlaTable(
        table_name="MY_TABLE",
        database=mock_database,
        schema="MY_SCHEMA",
        catalog="MY_DB",
    )

    # Render a SELECT over the table returned by the method under test.
    statement = select(dataset.get_sqla_table())
    compiled = str(
        statement.compile(pg_engine, compile_kwargs={"literal_binds": True})
    )

    # BAD:  "MY_DB.MY_SCHEMA.MY_TABLE"
    # Quoting the full path as a single identifier would make the database
    # report: SQL compilation error: Object '"MY_DB.MY_SCHEMA.MY_TABLE"'
    # does not exist
    assert '"MY_DB.MY_SCHEMA.MY_TABLE"' not in compiled

    # GOOD: "MY_DB"."MY_SCHEMA"."MY_TABLE"
    # Each component carries its own pair of quotes.
    assert '"MY_DB"."MY_SCHEMA"."MY_TABLE"' in compiled
Loading