from __future__ import annotations

from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union

from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL

try:
    from databricks.sql.utils import ParamEscaper
except ImportError:

    class ParamEscaper:  # type: ignore
        """Placeholder used when ``databricks.sql.utils`` cannot be imported."""


class DatabricksStringType(types.TypeDecorator):
    """
    String type whose inline literals are escaped by ``ParamEscaper``.

    Literal values are rendered through ``ParamEscaper.escape_string`` and, when
    the dialect's identifier preparer asks for it, percent signs are doubled.
    """

    impl = types.String
    cache_ok = True
    # A single escaper instance shared by every instance of this type.
    pe = ParamEscaper()

    def process_literal_param(self, value: Any, dialect: Any) -> str:
        """Return ``value`` escaped by the shared ``ParamEscaper``."""
        return self.pe.escape_string(value)

    def literal_processor(self, dialect: Any) -> Callable[[Any], str]:
        """Build the callable that renders a literal value for ``dialect``."""

        def process(value: Any) -> str:
            escaped = self.process_literal_param(value, dialect="databricks")
            # The flag is looked up on every call, after the value is escaped.
            double_percents = dialect.identifier_preparer._double_percents
            rendered = escaped.replace("%", "%%") if double_percents else escaped
            return "%s" % rendered

        return process


def monkeypatch_dialect() -> None:
    """
    Patch the Databricks dialect so string literals are escaped correctly.

    The Databricks SQLAlchemy dialect currently in use doubles single quotes
    instead of escaping them with a backslash. The fixed dialect needs
    SQLAlchemy 2.0, which Superset does not support yet, so the ``String``
    entry in the dialect's ``colspecs`` is replaced with
    ``DatabricksStringType``. Nothing happens when the dialect package cannot
    be imported.
    """
    try:
        from sqlalchemy_databricks._dialect import DatabricksDialect
    except ImportError:
        # Best effort: without the dialect there is nothing to patch.
        pass
    else:
        DatabricksDialect.colspecs[types.String] = DatabricksStringType


# remove once we've upgraded to SQLAlchemy 2.0 and the 2.x databricks-sqlalchemy lib
monkeypatch_dialect()