Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL

Expand Down Expand Up @@ -72,19 +73,32 @@ def process(value: Any) -> str:

def monkeypatch_dialect() -> None:
"""
Monkeypatch dialect to correctly escape single quotes.
Monkeypatch dialect to correctly escape single quotes for Databricks.

The Databricks SQLAlchemy dialect we currently use does not escape single quotes
correctly -- it doubles the single quotes, instead of adding a backslash. The fixed
version requires SQLAlchemy 2.0, which is not yet available in Superset.
The Databricks SQLAlchemy dialect (<3.0) incorrectly escapes single quotes by
doubling them ('O''Hara') instead of using backslash escaping ('O\'Hara'). The
fixed version requires SQLAlchemy>=2.0, which is not yet compatible with Superset.

Since the DatabricksDialect.colspecs points to the base class (HiveDialect.colspecs)
we can't patch it without affecting other Hive-based dialects. The solution is to
introduce a dialect-aware string type so that the change applies only to Databricks.
"""
try:
from sqlalchemy_databricks._dialect import DatabricksDialect
from pyhive.sqlalchemy_hive import HiveDialect

class ContextAwareStringType(types.TypeDecorator):
impl = types.String
cache_ok = True

def literal_processor(
self, dialect: DefaultDialect
) -> Callable[[Any], str]:
if dialect.__class__.__name__ == "DatabricksDialect":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small suggestion, leave the import from line 82 and change this to:

Suggested change
if dialect.__class__.__name__ == "DatabricksDialect":
if dialect.__class__ is DatabricksDialect:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried that (also with isclass() and issubclass()) but couldn't make it work. Will keep as is for now

return DatabricksStringType().literal_processor(dialect)
return super().literal_processor(dialect)

HiveDialect.colspecs[types.String] = ContextAwareStringType

# copy because the dictionary is shared with other subclasses if it hasn't been
# overwritten
DatabricksDialect.colspecs = DatabricksDialect.colspecs.copy()
DatabricksDialect.colspecs[types.String] = DatabricksStringType
except ImportError:
pass

Expand Down Expand Up @@ -646,5 +660,5 @@ def adjust_engine_params(
return uri, connect_args


# remove once we've upgraded to SQLAlchemy 2.0 and the 2.x databricks-sqlalchemy lib
# TODO: remove once we've upgraded to SQLAlchemy>=2.0 and databricks-sql-python>=3.x
monkeypatch_dialect()
Loading