diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index c8f62d06b..109daffa2 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -68,7 +68,7 @@ class QuerySnowflakeDTO(QueryDTO): class QueryDatabricksDTO(QueryDTO): - connection_info: DatabricksConnectionInfo = connection_info_field + connection_info: DatabricksConnectionUnion = connection_info_field class QueryTrinoDTO(QueryDTO): @@ -364,7 +364,8 @@ class SnowflakeConnectionInfo(BaseConnectionInfo): ) -class DatabricksConnectionInfo(BaseConnectionInfo): +class DatabricksTokenConnectionInfo(BaseConnectionInfo): + databricks_type: Literal["token"] = "token" server_hostname: SecretStr = Field( alias="serverHostname", description="the server hostname of your Databricks instance", @@ -382,6 +383,43 @@ class DatabricksConnectionInfo(BaseConnectionInfo): ) +# https://docs.databricks.com/aws/en/dev-tools/python-sql-connector#oauth-machine-to-machine-m2m-authentication +class DatabricksServicePrincipalConnectionInfo(BaseConnectionInfo): + databricks_type: Literal["service_principal"] = "service_principal" + server_hostname: SecretStr = Field( + alias="serverHostname", + description="the server hostname of your Databricks instance", + examples=["dbc-xxxxxxxx-xxxx.cloud.databricks.com"], + ) + http_path: SecretStr = Field( + alias="httpPath", + description="the HTTP path of your Databricks SQL warehouse", + examples=["/sql/1.0/warehouses/xxxxxxxx"], + ) + client_id: SecretStr = Field( + alias="clientId", + description="the client ID for OAuth M2M authentication", + examples=["your-client-id"], + ) + client_secret: SecretStr = Field( + alias="clientSecret", + description="the client secret for OAuth M2M authentication", + examples=["your-client-secret"], + ) + azure_tenant_id: SecretStr | None = Field( + alias="azureTenantId", + description="the Azure tenant ID for OAuth M2M authentication", + examples=["your-tenant-id"], + default=None, + ) + + 
+DatabricksConnectionUnion = Annotated[ +    Union[DatabricksTokenConnectionInfo, DatabricksServicePrincipalConnectionInfo], +    Field(discriminator="databricks_type"), +] + + class TrinoConnectionInfo(BaseConnectionInfo):     host: SecretStr = Field(         description="the hostname of your database", examples=["localhost"]     ) @@ -507,7 +545,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):     | RedshiftConnectionInfo     | RedshiftIAMConnectionInfo     | SnowflakeConnectionInfo -    | DatabricksConnectionInfo +    | DatabricksConnectionUnion     | TrinoConnectionInfo     | LocalFileConnectionInfo     | S3FileConnectionInfo diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index ab27fb4c7..ba9f7fbc5 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -27,6 +27,9 @@ class ClickHouseDbError(Exception): import pyarrow as pa import sqlglot.expressions as sge import trino +from databricks import sql as dbsql +from databricks.sdk.core import Config as DbConfig +from databricks.sdk.core import oauth_service_principal from duckdb import HTTPException, IOException from google.cloud import bigquery from google.oauth2 import service_account @@ -40,6 +43,9 @@ class ClickHouseDbError(Exception): from app.model import ( ConnectionInfo, +    DatabricksConnectionUnion, +    DatabricksServicePrincipalConnectionInfo, +    DatabricksTokenConnectionInfo, GcsFileConnectionInfo, MinioFileConnectionInfo, RedshiftConnectionInfo, @@ -88,6 +94,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): self._connector = RedshiftConnector(connection_info) elif data_source == DataSource.postgres: self._connector = PostgresConnector(connection_info) +        elif data_source == DataSource.databricks: +            self._connector = DatabricksConnector(connection_info) else: self._connector = SimpleConnector(data_source, connection_info) @@ -605,3 +613,54 @@ def close(self) -> None: self.connection.close() except Exception as e: logger.warning(f"Error closing 
Redshift connection: {e}") + + +class DatabricksConnector(SimpleConnector): + def __init__(self, connection_info: DatabricksConnectionUnion): + if isinstance(connection_info, DatabricksTokenConnectionInfo): + self.connection = dbsql.connect( + server_hostname=connection_info.server_hostname.get_secret_value(), + http_path=connection_info.http_path.get_secret_value(), + access_token=connection_info.access_token.get_secret_value(), + ) + elif isinstance(connection_info, DatabricksServicePrincipalConnectionInfo): + kwargs = { + "host": connection_info.server_hostname.get_secret_value(), + "client_id": connection_info.client_id.get_secret_value(), + "client_secret": connection_info.client_secret.get_secret_value(), + } + if connection_info.azure_tenant_id is not None: + kwargs["azure_tenant_id"] = ( + connection_info.azure_tenant_id.get_secret_value() + ) + + def credential_provider(): + return oauth_service_principal(DbConfig(**kwargs)) + + self.connection = dbsql.connect( + server_hostname=connection_info.server_hostname.get_secret_value(), + http_path=connection_info.http_path.get_secret_value(), + credentials_provider=credential_provider, + ) + + def query(self, sql, limit=None): + with closing(self.connection.cursor()) as cursor: + cursor.execute(sql) + + if limit is not None: + arrow_table = cursor.fetchmany_arrow(limit) + else: + arrow_table = cursor.fetchall_arrow() + + return arrow_table + + def dry_run(self, sql): + with closing(self.connection.cursor()) as cursor: + cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0") + + def close(self) -> None: + """Close the Databricks connection.""" + try: + self.connection.close() + except Exception as e: + logger.warning(f"Error closing Databricks connection: {e}") diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index e26c20cf9..ee944b77b 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -20,7 +20,8 @@ ClickHouseConnectionInfo, 
ConnectionInfo, ConnectionUrl, - DatabricksConnectionInfo, + DatabricksServicePrincipalConnectionInfo, + DatabricksTokenConnectionInfo, GcsFileConnectionInfo, LocalFileConnectionInfo, MinioFileConnectionInfo, @@ -180,7 +181,12 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo: case DataSource.gcs_file: return GcsFileConnectionInfo.model_validate(data) case DataSource.databricks: - return DatabricksConnectionInfo.model_validate(data) + if ( + "databricks_type" in data + and data["databricks_type"] == "service_principal" + ): + return DatabricksServicePrincipalConnectionInfo.model_validate(data) + return DatabricksTokenConnectionInfo.model_validate(data) case _: raise NotImplementedError(f"Unsupported data source: {self}") @@ -415,7 +421,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend: ) @staticmethod - def get_databricks_connection(info: DatabricksConnectionInfo) -> BaseBackend: + def get_databricks_connection(info: DatabricksTokenConnectionInfo) -> BaseBackend: return ibis.databricks.connect( server_hostname=info.server_hostname.get_secret_value(), http_path=info.http_path.get_secret_value(), diff --git a/ibis-server/app/model/metadata/databricks.py b/ibis-server/app/model/metadata/databricks.py index 0b8548892..32ad645ae 100644 --- a/ibis-server/app/model/metadata/databricks.py +++ b/ibis-server/app/model/metadata/databricks.py @@ -1,7 +1,7 @@ from loguru import logger -from app.model import DatabricksConnectionInfo -from app.model.data_source import DataSource +from app.model import DatabricksTokenConnectionInfo +from app.model.connector import DatabricksConnector from app.model.metadata.dto import ( Column, Constraint, @@ -33,9 +33,9 @@ class DatabricksMetadata(Metadata): - def __init__(self, connection_info: DatabricksConnectionInfo): + def __init__(self, connection_info: DatabricksTokenConnectionInfo): super().__init__(connection_info) - self.connection = DataSource.databricks.get_connection(connection_info) + 
self.connection = DatabricksConnector(connection_info) def get_table_list(self) -> list[Table]: sql = """ @@ -58,7 +58,7 @@ def get_table_list(self) -> list[Table]: WHERE c.TABLE_SCHEMA NOT IN ('information_schema') """ - response = self.connection.sql(sql).to_pandas().to_dict(orient="records") + response = self.connection.query(sql).to_pandas().to_dict(orient="records") unique_tables = {} for row in response: @@ -122,7 +122,7 @@ def get_constraints(self) -> list[Constraint]: AND ccu.constraint_schema = tc.constraint_schema WHERE tc.constraint_type = 'FOREIGN KEY' """ - res = self.connection.sql(sql).to_pandas().to_dict(orient="records") + res = self.connection.query(sql).to_pandas().to_dict(orient="records") constraints = [] for row in res: constraints.append( @@ -150,7 +150,7 @@ def get_constraints(self) -> list[Constraint]: def get_version(self) -> str: return ( - self.connection.sql("SELECT current_version().dbsql_version") + self.connection.query("SELECT current_version().dbsql_version") .to_pandas() .iloc[0, 0] ) diff --git a/ibis-server/poetry.lock b/ibis-server/poetry.lock index 87d5225ba..85abecd03 100644 --- a/ibis-server/poetry.lock +++ b/ibis-server/poetry.lock @@ -1138,6 +1138,28 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "databricks-sdk" +version = "0.73.0" +description = "Databricks SDK for Python (Beta)" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff"}, + {file = "databricks_sdk-0.73.0.tar.gz", hash = "sha256:db09eaaacd98e07dded78d3e7ab47d2f6c886e0380cb577977bd442bace8bd8d"}, +] + +[package.dependencies] +google-auth = ">=2.0,<3.0" +protobuf = ">=4.25.8,<5.26.dev0 || >5.29.0,<5.29.1 || >5.29.1,<5.29.2 || >5.29.2,<5.29.3 || >5.29.3,<5.29.4 || >5.29.4,<6.30.0 || >6.30.0,<6.30.1 || >6.30.1,<6.31.0 || 
>6.31.0,<7.0" +requests = ">=2.28.1,<3" + +[package.extras] +dev = ["autoflake", "black", "build", "databricks-connect", "httpx", "ipython", "ipywidgets", "isort", "langchain-openai ; python_version > \"3.7\"", "openai", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist (>=3.6.1,<4.0)", "requests-mock", "wheel"] +notebook = ["ipython (>=8,<10)", "ipywidgets (>=8,<9)"] +openai = ["httpx", "langchain-openai ; python_version > \"3.7\"", "openai"] + [[package]] name = "databricks-sql-connector" version = "4.1.4" @@ -8016,4 +8038,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "a6031c4395bd88833b09e4991d06e0c507d7a208322750ddd63903dd66484cf1" +content-hash = "586d92a9751279dba3cef903e99389af812e353f948ef9efaaf544a0240646f8" diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index eab86848f..a4641eab1 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -50,6 +50,7 @@ redshift_connector = "2.1.7" datafusion = "^47.0.0, <49.0.0" starlette = "^0.49.1" databricks-sql-connector = { version = "^4.0.1", extras = ["pyarrow"] } +databricks-sdk = "^0.73.0" [tool.poetry.group.jupyter] optional = true diff --git a/ibis-server/tests/routers/v3/connector/databricks/conftest.py b/ibis-server/tests/routers/v3/connector/databricks/conftest.py index 8b659e96a..a1b3b5888 100644 --- a/ibis-server/tests/routers/v3/connector/databricks/conftest.py +++ b/ibis-server/tests/routers/v3/connector/databricks/conftest.py @@ -54,7 +54,19 @@ def init_databricks(connection_info): @pytest.fixture(scope="module") def connection_info() -> dict[str, str]: return { - "serverHostname": os.getenv("DATABRICKS_SERVER_HOSTNAME"), - "httpPath": os.getenv("DATABRICKS_HTTP_PATH"), - "accessToken": os.getenv("DATABRICKS_TOKEN"), + "databricks_type": "token", + "serverHostname": 
os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"), + "httpPath": os.getenv("TEST_DATABRICKS_HTTP_PATH"), + "accessToken": os.getenv("TEST_DATABRICKS_TOKEN"), + } + + +@pytest.fixture(scope="module") +def service_principal_connection_info() -> dict[str, str]: + return { + "databricks_type": "service_principal", + "serverHostname": os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"), + "httpPath": os.getenv("TEST_DATABRICKS_HTTP_PATH"), + "clientId": os.getenv("TEST_DATABRICKS_CLIENT_ID"), + "clientSecret": os.getenv("TEST_DATABRICKS_CLIENT_SECRET"), } diff --git a/ibis-server/tests/routers/v3/connector/databricks/test_query.py b/ibis-server/tests/routers/v3/connector/databricks/test_query.py index b3f3dd3c1..9bc2d5764 100644 --- a/ibis-server/tests/routers/v3/connector/databricks/test_query.py +++ b/ibis-server/tests/routers/v3/connector/databricks/test_query.py @@ -49,6 +49,7 @@ ], }, ], + "dataSource": "databricks", } @@ -86,15 +87,32 @@ async def test_query(client, manifest_str, connection_info): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(38, 9)", + "totalprice": "decimal128(18, 2)", "orderdate": "date32[day]", "order_cust_key": "string", - "timestamp": "timestamp[us, tz=UTC]", - "timestamptz": "timestamp[us, tz=UTC]", - "test_null_time": "timestamp[us, tz=UTC]", + "timestamp": "timestamp[us, tz=Etc/UTC]", + "timestamptz": "timestamp[us, tz=Etc/UTC]", + "test_null_time": "timestamp[us, tz=Etc/UTC]", } +async def test_query_with_service_principal( + client, manifest_str, service_principal_connection_info +): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": service_principal_connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders ORDER BY orderkey LIMIT 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + async def 
test_query_with_limit(client, manifest_str, connection_info): response = await client.post( url=f"{base_url}/query",