Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions superset/sqllab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,31 +238,36 @@ def format_sql(self) -> FlaskResponse:
sql = model["sql"]
template_params = model.get("template_params")
database_id = model.get("database_id")
database_engine = None

# Process Jinja templates if template_params and database_id are provided
if template_params and database_id is not None:
# Process Jinja templates if template_params are provided
if database_id is not None:
database = DatabaseDAO.find_by_id(database_id)

if database:
try:
template_params = (
json.loads(template_params)
if isinstance(template_params, str)
else template_params
)
if template_params:
template_processor = get_template_processor(
database=database
database_engine = database.db_engine_spec.engine

if template_params:
try:
template_params = (
json.loads(template_params)
if isinstance(template_params, str)
else template_params
)
sql = template_processor.process_template(
sql, **template_params
if template_params:
template_processor = get_template_processor(
database=database
)
sql = template_processor.process_template(
sql, **template_params
)
except json.JSONDecodeError:
logger.warning(
"Invalid template parameter %s. Skipping processing",
str(template_params),
)
except json.JSONDecodeError:
logger.warning(
"Invalid template parameter %s. Skipping processing",
str(template_params),
)

result = SQLScript(sql, model.get("engine")).format()
result = SQLScript(sql, model.get("engine", database_engine)).format()
Comment thread
justinpark marked this conversation as resolved.
return self.response(200, result=result)
except ValidationError as error:
return self.response_400(message=error.messages)
Expand Down
33 changes: 31 additions & 2 deletions tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from superset.utils import core as utils, json
from superset.models.sql_lab import Query

from superset.sql.parse import SQLScript
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import (
ADMIN_USERNAME,
Expand Down Expand Up @@ -288,10 +289,38 @@ def test_format_sql_request(self):
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200

def test_format_sql_request_with_db_id(self):
self.login(ADMIN_USERNAME)
example_db = get_example_database()

# IIF is normalized differently per dialect:
# SQLite preserves IIF(), Postgres/base converts to CASE WHEN.
# Compute the expected result from the actual engine so the test is
# environment-independent.
sql = "select IIF(score > 0, 'positive', 'negative') from my_table"
engine = example_db.db_engine_spec.engine
expected = SQLScript(sql, engine).format()

data = {"sql": sql, "database_id": example_db.id}
rv = self.client.post(
"/api/v1/sqllab/format_sql/",
json=data,
)
resp_data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert resp_data["result"] == expected

def test_format_sql_request_with_jinja(self):
self.login(ADMIN_USERNAME)
example_db = get_example_database()

# Quoted identifier formatting varies by dialect (e.g., MySQL uses backticks).
# Compute the expected result from the actual engine so the test is
# environment-independent.
rendered_sql = 'select * from "Vehicle Sales"'
engine = example_db.db_engine_spec.engine
expected = SQLScript(rendered_sql, engine).format()

data = {
"sql": "select * from {{tbl}}",
"database_id": example_db.id,
Expand All @@ -302,10 +331,10 @@ def test_format_sql_request_with_jinja(self):
json=data,
)
resp_data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
# Verify that Jinja template was processed before formatting
assert "{{tbl}}" not in resp_data["result"]
assert '"Vehicle Sales"' in resp_data["result"]
assert rv.status_code == 200
assert resp_data["result"] == expected

@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
Expand Down
Loading