Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, TYPE_CHECKING, TypedDict, Union
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union

from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL

Expand All @@ -40,6 +41,51 @@
from superset.models.core import Database


try:
from databricks.sql.utils import ParamEscaper
except ImportError:

class ParamEscaper: # type: ignore
"""Dummy class."""


class DatabricksStringType(types.TypeDecorator):
impl = types.String
cache_ok = True
pe = ParamEscaper()

def process_literal_param(self, value: Any, dialect: Any) -> str:
return self.pe.escape_string(value)

def literal_processor(self, dialect: Any) -> Callable[[Any], str]:
def process(value: Any) -> str:
_step1 = self.process_literal_param(value, dialect="databricks")
if dialect.identifier_preparer._double_percents:
_step2 = _step1.replace("%", "%%")
else:
_step2 = _step1

return "%s" % _step2

return process


def monkeypatch_dialect() -> None:
"""
Monkeypatch dialect to correctly escape single quotes.

The Databricks SQLAlchemy dialect we currently use does not escape single quotes
correctly -- it doubles the single quotes, instead of adding a backslash. The fixed
version requires SQLAlchemy 2.0, which is not yet available in Superset.
"""
try:
from sqlalchemy_databricks._dialect import DatabricksDialect

DatabricksDialect.colspecs[types.String] = DatabricksStringType
except ImportError:
pass


class DatabricksBaseSchema(Schema):
"""
Fields that are required for both Databricks drivers that uses a
Expand Down Expand Up @@ -595,3 +641,7 @@ def adjust_engine_params(
uri = uri.update_query_dict({"schema": schema})

return uri, connect_args


# remove once we've upgraded to SQLAlchemy 2.0 and the 2.x databricks-sqlalchemy lib
monkeypatch_dialect()
Loading