Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions superset/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def run(self) -> None:
db.session.rollback()
raise self.import_error()

# pylint: disable=too-many-branches
def validate(self) -> None:
exceptions: List[ValidationError] = []

Expand Down Expand Up @@ -99,6 +100,10 @@ def validate(self) -> None:

# validate objects
for file_name, content in self.contents.items():
# skip directories
if not content:
continue

prefix = file_name.split("/")[0]
schema = self.schemas.get(f"{prefix}/")
if schema:
Expand Down
31 changes: 31 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,37 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# query costs before they run. These EXPLAIN queries should have a small
# timeout.
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10 # seconds
# The feature is off by default, and currently only supported in Presto and Postgres.
# It also need to be enabled on a per-database basis, by adding the key/value pair
# `cost_estimate_enabled: true` to the database `extra` attribute.
ESTIMATE_QUERY_COST = False
# The cost returned by the databases is a relative value; in order to map the cost to
# a tangible value you need to define a custom formatter that takes into consideration
# your specific infrastructure. For example, you could analyze queries a posteriori by
# running EXPLAIN on them, and compute a histogram of relative costs to present the
# cost as a percentile:
#
# def postgres_query_cost_formatter(
# result: List[Dict[str, Any]]
# ) -> List[Dict[str, str]]:
# # 25, 50, 75% percentiles
# percentile_costs = [100.0, 1000.0, 10000.0]
#
# out = []
# for row in result:
# relative_cost = row["Total cost"]
# percentile = bisect.bisect_left(percentile_costs, relative_cost) + 1
# out.append({
# "Relative cost": relative_cost,
# "Percentile": str(percentile * 25) + "%",
# })
#
# return out
#
# DEFAULT_FEATURE_FLAGS = {
# "ESTIMATE_QUERY_COST": True,
# "QUERY_COST_FORMATTERS_BY_ENGINE": {"postgresql": postgres_query_cost_formatter},
# }

# Flag that controls if limit should be enforced on the CTA (create table as queries).
SQLLAB_CTAS_NO_LIMIT = False
Expand Down
1 change: 1 addition & 0 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ class ImportV1DatabaseExtraSchema(Schema):
engine_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
metadata_cache_timeout = fields.Dict(keys=fields.Str(), values=fields.Integer())
schemas_allowed_for_csv_upload = fields.List(fields.String)
cost_estimate_enabled = fields.Boolean()


class ImportV1DatabaseSchema(Schema):
Expand Down
44 changes: 31 additions & 13 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine

from superset import app, sql_parse
from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
Expand Down Expand Up @@ -203,7 +203,7 @@ def is_db_column_type_match(
return any(pattern.match(db_column_type) for pattern in patterns)

@classmethod
def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return False

@classmethod
Expand Down Expand Up @@ -790,16 +790,12 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
return sql

@classmethod
def estimate_statement_cost(
cls, statement: str, database: "Database", cursor: Any, user_name: str
) -> Dict[str, Any]:
def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.

:param statement: A single SQL statement
:param database: Database instance
:param cursor: Cursor instance
:param username: Effective username
:return: Dictionary with different costs
"""
raise Exception("Database does not support cost estimation")
Expand All @@ -816,10 +812,31 @@ def query_cost_formatter(
"""
raise Exception("Database does not support cost estimation")

@classmethod
def process_statement(
cls, statement: str, database: "Database", user_name: str
) -> str:
"""
Process a SQL statement by stripping and mutating it.

:param statement: A single SQL statement
:param database: Database instance
:param username: Effective username
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()

sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
sql = sql_query_mutator(sql, user_name, security_manager, database)

return sql

@classmethod
def estimate_query_cost(
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
) -> List[Dict[str, str]]:
) -> List[Dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.

Expand All @@ -828,8 +845,8 @@ def estimate_query_cost(
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
"""
database_version = database.get_extra().get("version")
if not cls.get_allow_cost_estimate(database_version):
extra = database.get_extra() or {}
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

user_name = g.user.username if g.user else None
Expand All @@ -841,10 +858,11 @@ def estimate_query_cost(
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
for statement in statements:
processed_statement = cls.process_statement(
statement, database, user_name
)
costs.append(
cls.estimate_statement_cost(
statement, database, cursor, user_name
)
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs

Expand Down
28 changes: 27 additions & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

from pytz import _FixedOffset # type: ignore
from sqlalchemy.dialects.postgresql.base import PGInspector
Expand Down Expand Up @@ -71,6 +72,31 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
max_column_name_length = 63
try_remove_schema_from_table_name = False

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
sql = f"EXPLAIN {statement}"
cursor.execute(sql)

result = cursor.fetchone()[0]
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
if match:
return {
"Start-up cost": float(match.group(1)),
"Total cost": float(match.group(2)),
}

return {}

@classmethod
def query_cost_formatter(
cls, raw_cost: List[Dict[str, Any]]
) -> List[Dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]

@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
Expand Down
16 changes: 5 additions & 11 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

from superset import app, cache_manager, is_feature_enabled, security_manager
from superset import app, cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetTemplateException
Expand Down Expand Up @@ -132,7 +132,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
}

@classmethod
def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and StrictVersion(version) >= StrictVersion("0.319")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one liner:
return extra.get("version") and StrictVersion(version) >= StrictVersion("0.319")


@classmethod
Expand Down Expand Up @@ -484,7 +485,7 @@ def select_star( # pylint: disable=too-many-arguments

@classmethod
def estimate_statement_cost( # pylint: disable=too-many-locals
cls, statement: str, database: "Database", cursor: Any, user_name: str
cls, statement: str, cursor: Any
) -> Dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
Expand All @@ -495,14 +496,7 @@ def estimate_statement_cost( # pylint: disable=too-many-locals
:param username: Effective username
:return: JSON response from Presto
"""
parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()

sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
sql = sql_query_mutator(sql, user_name, security_manager, database)

sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
cursor.execute(sql)

# the output from Presto is a single column and a single row containing
Expand Down
7 changes: 2 additions & 5 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,11 @@ def function_names(self) -> List[str]:

@property
def allows_cost_estimate(self) -> bool:
extra = self.get_extra()

database_version = extra.get("version")
extra = self.get_extra() or {}
cost_estimate_enabled: bool = extra.get("cost_estimate_enabled") # type: ignore

return (
self.db_engine_spec.get_allow_cost_estimate(database_version)
and cost_estimate_enabled
self.db_engine_spec.get_allow_cost_estimate(extra) and cost_estimate_enabled
)

@property
Expand Down