Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class QuerySnowflakeDTO(QueryDTO):


class QueryDatabricksDTO(QueryDTO):
connection_info: DatabricksConnectionInfo = connection_info_field
connection_info: DatabricksConnectionUnion = connection_info_field


class QueryTrinoDTO(QueryDTO):
Expand Down Expand Up @@ -364,7 +364,8 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
)


class DatabricksConnectionInfo(BaseConnectionInfo):
class DatabricksTokenConnectionInfo(BaseConnectionInfo):
databricks_type: Literal["token"] = "token"
server_hostname: SecretStr = Field(
alias="serverHostname",
description="the server hostname of your Databricks instance",
Expand All @@ -382,6 +383,43 @@ class DatabricksConnectionInfo(BaseConnectionInfo):
)


# https://docs.databricks.com/aws/en/dev-tools/python-sql-connector#oauth-machine-to-machine-m2m-authentication
class DatabricksServicePrincipalConnectionInfo(BaseConnectionInfo):
    """Connection info for Databricks OAuth machine-to-machine (M2M)
    authentication with a service principal (client ID + client secret).

    Selected via the ``databricks_type`` discriminator; see
    ``DatabricksConnectionUnion``.
    """

    # Discriminator value that distinguishes this variant from token auth.
    databricks_type: Literal["service_principal"] = "service_principal"
    server_hostname: SecretStr = Field(
        alias="serverHostname",
        description="the server hostname of your Databricks instance",
        examples=["dbc-xxxxxxxx-xxxx.cloud.databricks.com"],
    )
    http_path: SecretStr = Field(
        alias="httpPath",
        description="the HTTP path of your Databricks SQL warehouse",
        examples=["/sql/1.0/warehouses/xxxxxxxx"],
    )
    client_id: SecretStr = Field(
        alias="clientId",
        description="the client ID for OAuth M2M authentication",
        examples=["your-client-id"],
    )
    client_secret: SecretStr = Field(
        alias="clientSecret",
        description="the client secret for OAuth M2M authentication",
        examples=["your-client-secret"],
    )
    # Optional: only needed for Azure-hosted workspaces.
    azure_tenant_id: SecretStr | None = Field(
        alias="azureTenantId",
        description="the Azure tenant ID for OAuth M2M authentication",
        examples=["your-tenant-id"],
        default=None,
    )


# Discriminated union of the supported Databricks authentication modes;
# pydantic selects the concrete model from the `databricks_type` field.
DatabricksConnectionUnion = Annotated[
    Union[DatabricksTokenConnectionInfo, DatabricksServicePrincipalConnectionInfo],
    Field(discriminator="databricks_type"),
]


class TrinoConnectionInfo(BaseConnectionInfo):
host: SecretStr = Field(
description="the hostname of your database", examples=["localhost"]
Expand Down Expand Up @@ -507,7 +545,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
| RedshiftConnectionInfo
| RedshiftIAMConnectionInfo
| SnowflakeConnectionInfo
| DatabricksConnectionInfo
| DatabricksTokenConnectionInfo
| TrinoConnectionInfo
| LocalFileConnectionInfo
| S3FileConnectionInfo
Expand Down
59 changes: 59 additions & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class ClickHouseDbError(Exception):
import pyarrow as pa
import sqlglot.expressions as sge
import trino
from databricks import sql as dbsql
from databricks.sdk.core import Config as DbConfig
from databricks.sdk.core import oauth_service_principal
from duckdb import HTTPException, IOException
from google.cloud import bigquery
from google.oauth2 import service_account
Expand All @@ -40,6 +43,9 @@ class ClickHouseDbError(Exception):

from app.model import (
ConnectionInfo,
DatabricksConnectionUnion,
DatabricksServicePrincipalConnectionInfo,
DatabricksTokenConnectionInfo,
GcsFileConnectionInfo,
MinioFileConnectionInfo,
RedshiftConnectionInfo,
Expand Down Expand Up @@ -88,6 +94,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
self._connector = RedshiftConnector(connection_info)
elif data_source == DataSource.postgres:
self._connector = PostgresConnector(connection_info)
elif data_source == DataSource.databricks:
self._connector = DatabricksConnector(connection_info)
else:
self._connector = SimpleConnector(data_source, connection_info)

Expand Down Expand Up @@ -605,3 +613,54 @@ def close(self) -> None:
self.connection.close()
except Exception as e:
logger.warning(f"Error closing Redshift connection: {e}")


class DatabricksConnector(SimpleConnector):
    """Connector for Databricks SQL warehouses.

    Supports two authentication modes:

    * personal access token (``DatabricksTokenConnectionInfo``)
    * OAuth machine-to-machine service principal
      (``DatabricksServicePrincipalConnectionInfo``)

    NOTE: intentionally does not call ``SimpleConnector.__init__`` — the
    connection is opened directly with ``databricks-sql-connector``
    instead of going through an ibis backend.
    """

    def __init__(self, connection_info: DatabricksConnectionUnion):
        """Open a Databricks connection for the given credentials.

        Raises:
            ValueError: if ``connection_info`` is not one of the supported
                Databricks connection-info types.
        """
        if isinstance(connection_info, DatabricksTokenConnectionInfo):
            # Personal-access-token authentication.
            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                access_token=connection_info.access_token.get_secret_value(),
            )
        elif isinstance(connection_info, DatabricksServicePrincipalConnectionInfo):
            # OAuth M2M authentication, see:
            # https://docs.databricks.com/aws/en/dev-tools/python-sql-connector#oauth-machine-to-machine-m2m-authentication
            config_kwargs = {
                "host": connection_info.server_hostname.get_secret_value(),
                "client_id": connection_info.client_id.get_secret_value(),
                "client_secret": connection_info.client_secret.get_secret_value(),
            }
            # azure_tenant_id is only required for Azure-hosted workspaces.
            if connection_info.azure_tenant_id is not None:
                config_kwargs["azure_tenant_id"] = (
                    connection_info.azure_tenant_id.get_secret_value()
                )

            def credential_provider():
                return oauth_service_principal(DbConfig(**config_kwargs))

            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                credentials_provider=credential_provider,
            )
        else:
            # Fail fast: without this branch self.connection would stay
            # unset and later calls would die with a confusing AttributeError.
            raise ValueError(
                f"Unsupported Databricks connection info: {type(connection_info).__name__}"
            )

    def query(self, sql, limit=None):
        """Execute ``sql`` and return the result as a pyarrow Table.

        If ``limit`` is given, at most that many rows are fetched.
        """
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(sql)
            if limit is not None:
                return cursor.fetchmany_arrow(limit)
            return cursor.fetchall_arrow()

    def dry_run(self, sql):
        """Validate ``sql`` without materializing any rows (LIMIT 0 wrapper)."""
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0")

    def close(self) -> None:
        """Close the Databricks connection, logging (not raising) failures."""
        try:
            self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Databricks connection: {e}")
12 changes: 9 additions & 3 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
ClickHouseConnectionInfo,
ConnectionInfo,
ConnectionUrl,
DatabricksConnectionInfo,
DatabricksServicePrincipalConnectionInfo,
DatabricksTokenConnectionInfo,
GcsFileConnectionInfo,
LocalFileConnectionInfo,
MinioFileConnectionInfo,
Expand Down Expand Up @@ -180,7 +181,12 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
case DataSource.gcs_file:
return GcsFileConnectionInfo.model_validate(data)
case DataSource.databricks:
return DatabricksConnectionInfo.model_validate(data)
if (
"databricks_type" in data
and data["databricks_type"] == "service_principal"
):
return DatabricksServicePrincipalConnectionInfo.model_validate(data)
return DatabricksTokenConnectionInfo.model_validate(data)
case _:
raise NotImplementedError(f"Unsupported data source: {self}")

Expand Down Expand Up @@ -415,7 +421,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
)

@staticmethod
def get_databricks_connection(info: DatabricksConnectionInfo) -> BaseBackend:
def get_databricks_connection(info: DatabricksTokenConnectionInfo) -> BaseBackend:
return ibis.databricks.connect(
server_hostname=info.server_hostname.get_secret_value(),
http_path=info.http_path.get_secret_value(),
Expand Down
14 changes: 7 additions & 7 deletions ibis-server/app/model/metadata/databricks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from loguru import logger

from app.model import DatabricksConnectionInfo
from app.model.data_source import DataSource
from app.model import DatabricksTokenConnectionInfo
from app.model.connector import DatabricksConnector
from app.model.metadata.dto import (
Column,
Constraint,
Expand Down Expand Up @@ -33,9 +33,9 @@


class DatabricksMetadata(Metadata):
def __init__(self, connection_info: DatabricksConnectionInfo):
def __init__(self, connection_info: DatabricksTokenConnectionInfo):
super().__init__(connection_info)
self.connection = DataSource.databricks.get_connection(connection_info)
self.connection = DatabricksConnector(connection_info)

def get_table_list(self) -> list[Table]:
sql = """
Expand All @@ -58,7 +58,7 @@ def get_table_list(self) -> list[Table]:
WHERE
c.TABLE_SCHEMA NOT IN ('information_schema')
"""
response = self.connection.sql(sql).to_pandas().to_dict(orient="records")
response = self.connection.query(sql).to_pandas().to_dict(orient="records")

unique_tables = {}
for row in response:
Expand Down Expand Up @@ -122,7 +122,7 @@ def get_constraints(self) -> list[Constraint]:
AND ccu.constraint_schema = tc.constraint_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
"""
res = self.connection.sql(sql).to_pandas().to_dict(orient="records")
res = self.connection.query(sql).to_pandas().to_dict(orient="records")
constraints = []
for row in res:
constraints.append(
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_constraints(self) -> list[Constraint]:

def get_version(self) -> str:
return (
self.connection.sql("SELECT current_version().dbsql_version")
self.connection.query("SELECT current_version().dbsql_version")
.to_pandas()
.iloc[0, 0]
)
Expand Down
24 changes: 23 additions & 1 deletion ibis-server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ redshift_connector = "2.1.7"
datafusion = "^47.0.0, <49.0.0"
starlette = "^0.49.1"
databricks-sql-connector = { version = "^4.0.1", extras = ["pyarrow"] }
databricks-sdk = "^0.73.0"

[tool.poetry.group.jupyter]
optional = true
Expand Down
18 changes: 15 additions & 3 deletions ibis-server/tests/routers/v3/connector/databricks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,19 @@ def init_databricks(connection_info):
@pytest.fixture(scope="module")
def connection_info() -> dict[str, str]:
return {
"serverHostname": os.getenv("DATABRICKS_SERVER_HOSTNAME"),
"httpPath": os.getenv("DATABRICKS_HTTP_PATH"),
"accessToken": os.getenv("DATABRICKS_TOKEN"),
"databricks_type": "token",
"serverHostname": os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"),
"httpPath": os.getenv("TEST_DATABRICKS_HTTP_PATH"),
"accessToken": os.getenv("TEST_DATABRICKS_TOKEN"),
}


@pytest.fixture(scope="module")
def service_principal_connection_info() -> dict[str, str]:
    """Connection payload for OAuth M2M (service-principal) auth, read from env."""
    return dict(
        databricks_type="service_principal",
        serverHostname=os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"),
        httpPath=os.getenv("TEST_DATABRICKS_HTTP_PATH"),
        clientId=os.getenv("TEST_DATABRICKS_CLIENT_ID"),
        clientSecret=os.getenv("TEST_DATABRICKS_CLIENT_SECRET"),
    )
26 changes: 22 additions & 4 deletions ibis-server/tests/routers/v3/connector/databricks/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
],
},
],
"dataSource": "databricks",
}


Expand Down Expand Up @@ -86,15 +87,32 @@ async def test_query(client, manifest_str, connection_info):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(38, 9)",
"totalprice": "decimal128(18, 2)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[us, tz=UTC]",
"timestamptz": "timestamp[us, tz=UTC]",
"test_null_time": "timestamp[us, tz=UTC]",
"timestamp": "timestamp[us, tz=Etc/UTC]",
"timestamptz": "timestamp[us, tz=Etc/UTC]",
"test_null_time": "timestamp[us, tz=Etc/UTC]",
}


async def test_query_with_service_principal(
    client, manifest_str, service_principal_connection_info
):
    """Run a query authenticated via OAuth M2M service-principal credentials."""
    payload = {
        "connectionInfo": service_principal_connection_info,
        "manifestStr": manifest_str,
        "sql": "SELECT * FROM wren.public.orders ORDER BY orderkey LIMIT 1",
    }
    response = await client.post(url=f"{base_url}/query", json=payload)
    assert response.status_code == 200
    body = response.json()
    assert len(body["columns"]) == len(manifest["models"][0]["columns"])
    assert len(body["data"]) == 1


async def test_query_with_limit(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
Expand Down