Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def is_str_or_adhoc(metric: Metric) -> bool:
return isinstance(metric, str) or is_adhoc_metric(metric)

self.metrics = metrics and [
x if is_str_or_adhoc(x) else x["label"] # type: ignore
x if is_str_or_adhoc(x) else x["label"] # type: ignore[misc,index]
for x in metrics
]

Expand Down Expand Up @@ -285,6 +285,7 @@ def validate(
self._validate_no_have_duplicate_labels()
self._validate_time_offsets()
self._sanitize_filters()
self._sanitize_sql_expressions()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
Expand Down Expand Up @@ -359,6 +360,95 @@ def _sanitize_filters(self) -> None:
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

def _sanitize_sql_expressions(self) -> None:
    """
    Sanitize SQL expressions in adhoc metrics and orderby for consistent cache keys.

    The expressions are processed here, before cache key generation, so that
    later processing during query execution cannot produce a different key.

    Each step is guarded on the datasource processor it actually invokes:
    metrics need ``_process_select_expression`` and orderby items need
    ``_process_orderby_expression``. When there is no datasource, or the
    datasource does not provide the required processor, that step is skipped
    and the corresponding attribute is left untouched.
    """
    datasource = self.datasource
    if not datasource:
        return

    # Adhoc metrics are rewritten through the SELECT-expression processor.
    if self.metrics and callable(
        getattr(datasource, "_process_select_expression", None)
    ):
        self._sanitize_metrics_expressions()

    # Orderby items may contain adhoc metrics; they are rewritten through the
    # ORDER BY-expression processor.
    if self.orderby and callable(
        getattr(datasource, "_process_orderby_expression", None)
    ):
        self._sanitize_orderby_expressions()

def _sanitize_metrics_expressions(self) -> None:
    """
    Run the datasource's SELECT-expression processing over adhoc metrics.

    Metrics that are not adhoc dictionaries, that carry no ``sqlExpression``,
    or whose expression comes back empty or unchanged are kept as the same
    objects. A metric whose expression changes is replaced by a fresh
    dictionary, so dictionaries shared with other holders are never mutated.
    """
    # The caller has already verified the datasource; this only narrows the type.
    assert self.datasource is not None

    if not self.metrics:
        return

    cleaned = []
    for entry in self.metrics:
        if not is_adhoc_metric(entry) or not isinstance(entry, dict):
            cleaned.append(entry)
            continue

        raw_sql = entry.get("sqlExpression")
        if not raw_sql:
            cleaned.append(entry)
            continue

        rewritten = self.datasource._process_select_expression(
            expression=raw_sql,
            database_id=self.datasource.database_id,
            engine=self.datasource.database.backend,
            schema=self.datasource.schema,
            template_processor=None,
        )
        unchanged = not rewritten or rewritten == raw_sql
        # Copy-on-change: only build a new dict when the expression differs.
        cleaned.append(
            entry if unchanged else {**entry, "sqlExpression": rewritten}
        )

    self.metrics = cleaned

def _sanitize_orderby_expressions(self) -> None:
    """
    Run the datasource's ORDER BY-expression processing over orderby items.

    Every item is re-emitted as a ``(column, ascending)`` tuple. Only a
    dictionary column with a non-empty ``sqlExpression`` is processed; if the
    result is non-empty and differs from the input, the column is replaced by
    a fresh dictionary so that shared references are never mutated. In every
    other case the original column object is reused.
    """
    # The caller has already verified the datasource; this only narrows the type.
    assert self.datasource is not None

    if not self.orderby:
        return

    rebuilt = []
    for target, is_ascending in self.orderby:
        if isinstance(target, dict) and target.get("sqlExpression"):
            raw_sql = target["sqlExpression"]
            rewritten = self.datasource._process_orderby_expression(
                expression=raw_sql,
                database_id=self.datasource.database_id,
                engine=self.datasource.database.backend,
                schema=self.datasource.schema,
                template_processor=None,
            )
            if rewritten and rewritten != raw_sql:
                # Copy-on-change: leave the caller's dictionary untouched.
                rebuilt.append(
                    ({**target, "sqlExpression": rewritten}, is_ascending)  # type: ignore[arg-type]
                )
                continue

        # Plain columns, and expressions that came back empty or unchanged.
        rebuilt.append((target, is_ascending))

    self.orderby = rebuilt

def _validate_there_are_no_missing_series(self) -> None:
missing_series = [col for col in self.series_columns if col not in self.columns]
if missing_series:
Expand Down
23 changes: 15 additions & 8 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,10 @@ def adhoc_metric_to_sqla(
schema=self.schema,
template_processor=template_processor,
)
elif template_processor and expression:
# Even if already processed (sanitized), we still need to
# render Jinja templates
expression = template_processor.process_template(expression)

sqla_metric = literal_column(expression)
else:
Expand Down Expand Up @@ -1819,6 +1823,10 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
# SQL expressions are sanitized during QueryObject.validate() via
# _sanitize_sql_expressions(), but we still process here to handle
# Jinja templates. sanitize_clause() is idempotent so re-sanitizing
# is safe.
metrics_exprs.append(
self.adhoc_metric_to_sqla(
metric=metric,
Expand Down Expand Up @@ -1855,19 +1863,18 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = self._process_orderby_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
# SQL expressions are processed during QueryObject.validate() via
# _sanitize_sql_expressions() using ORDER BY wrapping. We pass
# processed=True to skip re-processing and avoid incorrect SELECT
# wrapping that breaks ORDER BY expressions. The removal of the
# _process_orderby_expression() call (which mutated the dict) prevents
# cache key mismatches.
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(
col,
columns_by_name,
template_processor=template_processor,
processed=True,
)
# use the existing instance, if possible
Expand Down
Loading
Loading