diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 7f84802080da..55ec2d147922 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -25,6 +25,7 @@ from marshmallow import fields, Schema from marshmallow.validate import Range from sqlalchemy import types +from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -72,19 +73,32 @@ def process(value: Any) -> str: def monkeypatch_dialect() -> None: """ - Monkeypatch dialect to correctly escape single quotes. + Monkeypatch dialect to correctly escape single quotes for Databricks. - The Databricks SQLAlchemy dialect we currently use does not escape single quotes - correctly -- it doubles the single quotes, instead of adding a backslash. The fixed - version requires SQLAlchemy 2.0, which is not yet available in Superset. + The Databricks SQLAlchemy dialect (<3.0) incorrectly escapes single quotes by + doubling them ('O''Hara') instead of using backslash escaping ('O\'Hara'). The + fixed version requires SQLAlchemy>=2.0, which is not yet compatible with Superset. + + Since the DatabricksDialect.colspecs points to the base class (HiveDialect.colspecs) + we can't patch it without affecting other Hive-based dialects. The solution is to + introduce a dialect-aware string type so that the change applies only to Databricks. """ try: - from sqlalchemy_databricks._dialect import DatabricksDialect + from pyhive.sqlalchemy_hive import HiveDialect + + class ContextAwareStringType(types.TypeDecorator): + impl = types.String + cache_ok = True + + def literal_processor( + self, dialect: DefaultDialect + ) -> Callable[[Any], str]: + if dialect.__class__.__name__ == "DatabricksDialect": + return DatabricksStringType().literal_processor(dialect) + return super().literal_processor(dialect) + + HiveDialect.colspecs[types.String] = ContextAwareStringType - # copy because the dictionary is shared with other subclasses if it hasn't been - # overwritten - DatabricksDialect.colspecs = DatabricksDialect.colspecs.copy() - DatabricksDialect.colspecs[types.String] = DatabricksStringType except ImportError: pass @@ -646,5 +660,5 @@ def adjust_engine_params( return uri, connect_args -# remove once we've upgraded to SQLAlchemy 2.0 and the 2.x databricks-sqlalchemy lib +# TODO: remove once we've upgraded to SQLAlchemy>=2.0 and databricks-sql-python>=3.x monkeypatch_dialect()