Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@

import pandas as pd
import requests
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
Expand All @@ -57,14 +56,19 @@
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE

from superset import db, sql_parse
from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql.parse import BaseSQLStatement, LimitMethod, SQLScript, Table
from superset.sql.parse import (
BaseSQLStatement,
LimitMethod,
SQLScript,
SQLStatement,
Table,
)
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
Expand Down Expand Up @@ -1124,18 +1128,9 @@ def get_cte_query(cls, sql: str) -> str | None:

"""
if not cls.allows_cte_in_subquery:
stmt = sqlparse.parse(sql)[0]

# The first meaningful token for CTE will be with WITH
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
if not (token and token.ttype == CTE):
return None
idx, token = stmt.token_next(idx)
idx = stmt.token_index(token) + 1

# extract rest of the SQLs after CTE
remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip()
return f"WITH {token.value},\n{cls.cte_alias} AS (\n{remainder}\n)"
statement = SQLStatement(sql, engine=cls.engine)
if statement.has_cte():
return statement.as_cte(cls.cte_alias).format()

return None

Expand Down
47 changes: 47 additions & 0 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,23 @@ def set_limit_value(
"""
raise NotImplementedError()

def has_cte(self) -> bool:
    """
    Report whether the statement defines a CTE at its top level.

    This version is only a placeholder: it performs no inspection and always
    raises.

    :return: True if the statement has a CTE at the top level.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError

def as_cte(self, alias: str = "__cte") -> SQLStatement:
    """
    Return the statement rewritten as a CTE.

    This version is only a placeholder: it performs no rewrite and always
    raises, regardless of the alias given.

    :param alias: The alias to use for the CTE.
    :return: A new SQLStatement with the CTE.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError

def __str__(self) -> str:
return self.format()

Expand Down Expand Up @@ -526,6 +543,36 @@ def set_limit_value(
else: # method == LimitMethod.FETCH_MANY
pass

def has_cte(self) -> bool:
    """
    Check if the statement has a CTE.

    The check looks at the value stored under the ``"with"`` argument of the
    parsed AST rather than at the mere presence of the key: the key can be
    present but set to ``None`` (for example after the WITH clause has been
    cleared), and that must not be reported as a CTE.

    :return: True if the statement has a CTE at the top level.
    """
    return self._parsed.args.get("with") is not None

def as_cte(self, alias: str = "__cte") -> SQLStatement:
    """
    Rewrite the statement as a CTE.

    This is needed by MS SQL when the query includes CTEs. In that case the CTEs
    need to be moved to the top of the query when we wrap it as a subquery when
    building charts.

    The rewrite is performed on a copy of the parsed AST, so this statement is
    left unchanged and can still be formatted or rewritten again afterwards.

    :param alias: The alias to use for the CTE.
    :return: A new SQLStatement with the CTE.
    """
    # Work on a copy: hoisting the CTEs re-parents their nodes and strips the
    # WITH clause, which must not leak into this statement's own AST.
    body = self._parsed.copy()

    # Detach the WITH clause (if any) from the copied body and keep its CTEs so
    # they can be placed ahead of the new one.
    with_clause = body.args.pop("with", None)
    existing_ctes = with_clause.expressions if with_clause is not None else []

    new_cte = exp.CTE(
        this=body,
        alias=exp.TableAlias(this=exp.Identifier(this=alias)),
    )
    return SQLStatement(
        ast=exp.With(expressions=[*existing_ctes, new_cte], this=None),
        engine=self.engine,
    )


class KQLSplitState(enum.Enum):
"""
Expand Down
26 changes: 15 additions & 11 deletions tests/unit_tests/db_engine_specs/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,21 @@ def test_column_datatype_to_string(original: TypeEngine, expected: str) -> None:
select * from currency union all select * from currency_2
"""
),
dedent(
"""WITH currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
),
__cte AS (
select * from currency union all select * from currency_2
)"""
),
"""WITH currency AS (
SELECT
'INR' AS cur
), currency_2 AS (
SELECT
'EUR' AS cur
), __cte AS (
SELECT
*
FROM currency
UNION ALL
SELECT
*
FROM currency_2
)""",
),
(
"SELECT 1 as cnt",
Expand Down
106 changes: 106 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,3 +1519,109 @@ def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
statement = KustoKQLStatement(kql, "kustokql")
statement.set_limit_value(limit)
assert statement.format() == expected


@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # Statements without a WITH clause: no CTE expected.
        ("SELECT 1", "postgresql", False),
        ("SELECT 1 AS cnt", "postgresql", False),
        (
            """
SELECT 'INR' AS cur
UNION
SELECT 'USD' AS cur
UNION
SELECT 'EUR' AS cur
""",
            "postgresql",
            False,
        ),
        # Statements starting with a WITH clause: a CTE is expected.
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        (
            """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
""",
            "postgresql",
            True,
        ),
        # Same shape as above, but each CTE selects from the previous one.
        (
            """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
""",
            "postgresql",
            True,
        ),
        # CTE declared with an explicit column list and a trailing semicolon.
        (
            """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
""",
            "postgresql",
            True,
        ),
    ],
)
def test_has_cte(sql: str, engine: str, expected: bool) -> None:
    """
    Test that the parser detects CTEs correctly.

    Each case builds a ``SQLStatement`` from ``sql`` for ``engine`` and checks
    that ``has_cte()`` returns ``expected``.
    """
    assert SQLStatement(sql, engine).has_cte() == expected


@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # A statement without CTEs is wrapped in a single CTE.
        (
            "SELECT 1",
            "postgresql",
            "WITH __cte AS (\n SELECT\n 1\n)",
        ),
        # Existing CTEs are kept first and the remaining query is appended as
        # one more CTE.
        (
            """
WITH currency AS (SELECT 'INR' AS cur),
currency_2 AS (SELECT 'USD' AS cur)
SELECT * FROM currency
UNION ALL
SELECT * FROM currency_2
""",
            "postgresql",
            """WITH currency AS (
SELECT
'INR' AS cur
), currency_2 AS (
SELECT
'USD' AS cur
), __cte AS (
SELECT
*
FROM currency
UNION ALL
SELECT
*
FROM currency_2
)""",
        ),
    ],
)
def test_as_cte(sql: str, engine: str, expected: str) -> None:
    """
    Test that we can convert a SELECT statement to a CTE.

    Each case builds a ``SQLStatement`` from ``sql`` for ``engine``, calls
    ``as_cte()`` with its default alias and compares the formatted result with
    ``expected``.
    """
    assert SQLStatement(sql, engine).as_cte().format() == expected
Loading