Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ibis-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION: ${{ secrets.AWS_REGION }}
AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }}
run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file and not athena and not redshift"
run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file and not athena and not redshift and not databricks"
- name: Test bigquery if need
if: contains(github.event.pull_request.labels.*.name, 'bigquery')
env:
Expand Down
23 changes: 23 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class QuerySnowflakeDTO(QueryDTO):
connection_info: SnowflakeConnectionInfo = connection_info_field


class QueryDatabricksDTO(QueryDTO):
    # Query payload for the Databricks data source: the shared QueryDTO fields
    # plus Databricks-specific connection credentials.
    connection_info: DatabricksConnectionInfo = connection_info_field


class QueryTrinoDTO(QueryDTO):
    # Query payload for Trino: accepts either a raw connection URL or a
    # structured TrinoConnectionInfo object.
    connection_info: ConnectionUrl | TrinoConnectionInfo = connection_info_field

Expand Down Expand Up @@ -360,6 +364,24 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
)


class DatabricksConnectionInfo(BaseConnectionInfo):
    # Connection credentials for a Databricks SQL warehouse.
    # All fields are SecretStr so values are masked in logs/reprs; camelCase
    # aliases match the JSON payload sent by API clients.
    # NOTE(review): no port/catalog/schema fields — presumably the defaults of
    # the underlying driver are used; confirm against the connect call.
    server_hostname: SecretStr = Field(
        alias="serverHostname",
        description="the server hostname of your Databricks instance",
        examples=["dbc-xxxxxxxx-xxxx.cloud.databricks.com"],
    )
    http_path: SecretStr = Field(
        alias="httpPath",
        description="the HTTP path of your Databricks SQL warehouse",
        examples=["/sql/1.0/warehouses/xxxxxxxx"],
    )
    access_token: SecretStr = Field(
        alias="accessToken",
        description="the access token for your Databricks instance",
        examples=["XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"],
    )


class TrinoConnectionInfo(BaseConnectionInfo):
host: SecretStr = Field(
description="the hostname of your database", examples=["localhost"]
Expand Down Expand Up @@ -485,6 +507,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
| RedshiftConnectionInfo
| RedshiftIAMConnectionInfo
| SnowflakeConnectionInfo
| DatabricksConnectionInfo
| TrinoConnectionInfo
| LocalFileConnectionInfo
| S3FileConnectionInfo
Expand Down
14 changes: 14 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ClickHouseConnectionInfo,
ConnectionInfo,
ConnectionUrl,
DatabricksConnectionInfo,
GcsFileConnectionInfo,
LocalFileConnectionInfo,
MinioFileConnectionInfo,
Expand All @@ -31,6 +32,7 @@
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
QueryDatabricksDTO,
QueryDTO,
QueryGcsFileDTO,
QueryLocalFileDTO,
Expand Down Expand Up @@ -71,6 +73,7 @@ class DataSource(StrEnum):
s3_file = auto()
minio_file = auto()
gcs_file = auto()
databricks = auto()

def get_connection(self, info: ConnectionInfo) -> BaseBackend:
try:
Expand Down Expand Up @@ -176,6 +179,8 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
return MinioFileConnectionInfo.model_validate(data)
case DataSource.gcs_file:
return GcsFileConnectionInfo.model_validate(data)
case DataSource.databricks:
return DatabricksConnectionInfo.model_validate(data)
case _:
raise NotImplementedError(f"Unsupported data source: {self}")

Expand Down Expand Up @@ -225,6 +230,7 @@ class DataSourceExtension(Enum):
s3_file = QueryS3FileDTO
minio_file = QueryMinioFileDTO
gcs_file = QueryGcsFileDTO
databricks = QueryDatabricksDTO

def __init__(self, dto: QueryDTO):
self.dto = dto
Expand Down Expand Up @@ -408,6 +414,14 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
**info.kwargs if info.kwargs else dict(),
)

@staticmethod
def get_databricks_connection(info: DatabricksConnectionInfo) -> BaseBackend:
    """Open an ibis Databricks backend using the supplied credentials."""
    # Unwrap the SecretStr fields once, then hand them to the backend.
    hostname = info.server_hostname.get_secret_value()
    warehouse_path = info.http_path.get_secret_value()
    token = info.access_token.get_secret_value()
    return ibis.databricks.connect(
        server_hostname=hostname,
        http_path=warehouse_path,
        access_token=token,
    )

@staticmethod
def _create_ssl_context(info: ConnectionInfo) -> ssl.SSLContext | None:
ssl_mode = (
Expand Down
187 changes: 187 additions & 0 deletions ibis-server/app/model/metadata/databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from loguru import logger

from app.model import DatabricksConnectionInfo
from app.model.data_source import DataSource
from app.model.metadata.dto import (
Column,
Constraint,
ConstraintType,
RustWrenEngineColumnType,
Table,
TableProperties,
)
from app.model.metadata.metadata import Metadata

# https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-datatypes
# Maps lowercase Databricks SQL type names to engine column types.
# Parameterized types (decimal(p,s), geography, geometry) and complex types
# (array/map/struct) are handled separately by prefix checks before this
# table is consulted.
DATABRICKS_TYPE_MAPPING = {
    "bigint": RustWrenEngineColumnType.BIGINT,
    "binary": RustWrenEngineColumnType.BYTEA,
    "boolean": RustWrenEngineColumnType.BOOL,
    "date": RustWrenEngineColumnType.DATE,
    "decimal": RustWrenEngineColumnType.DECIMAL,
    "double": RustWrenEngineColumnType.DOUBLE,
    "float": RustWrenEngineColumnType.FLOAT,
    "int": RustWrenEngineColumnType.INTEGER,
    "smallint": RustWrenEngineColumnType.SMALLINT,
    "string": RustWrenEngineColumnType.STRING,
    "timestamp": RustWrenEngineColumnType.TIMESTAMP,
    # timestamp_ntz (no time zone) is folded into the same TIMESTAMP type.
    "timestamp_ntz": RustWrenEngineColumnType.TIMESTAMP,
    "tinyint": RustWrenEngineColumnType.TINYINT,
    "variant": RustWrenEngineColumnType.VARIANT,
    "object": RustWrenEngineColumnType.JSON,
}


class DatabricksMetadata(Metadata):
    """Metadata provider for Databricks: tables, columns, FK constraints, version."""

    def __init__(self, connection_info: DatabricksConnectionInfo):
        super().__init__(connection_info)
        # Reuse the shared DataSource factory so connection setup stays in one place.
        self.connection = DataSource.databricks.get_connection(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return every non-system table with its columns and comments."""
        sql = """
        SELECT
            c.TABLE_CATALOG AS TABLE_CATALOG,
            c.TABLE_SCHEMA AS TABLE_SCHEMA,
            c.TABLE_NAME AS TABLE_NAME,
            c.COLUMN_NAME AS COLUMN_NAME,
            c.DATA_TYPE AS DATA_TYPE,
            c.IS_NULLABLE AS IS_NULLABLE,
            c.COMMENT AS COLUMN_COMMENT,
            t.COMMENT AS TABLE_COMMENT
        FROM
            INFORMATION_SCHEMA.COLUMNS c
        JOIN
            INFORMATION_SCHEMA.TABLES t
            ON c.TABLE_SCHEMA = t.TABLE_SCHEMA
            AND c.TABLE_NAME = t.TABLE_NAME
            AND c.TABLE_CATALOG = t.TABLE_CATALOG
        WHERE
            c.TABLE_SCHEMA NOT IN ('information_schema')
        """
        records = self.connection.sql(sql).to_pandas().to_dict(orient="records")

        # One row per column; fold rows into a table keyed by the compact name.
        tables: dict[str, Table] = {}
        for record in records:
            key = self._format_compact_table_name(
                record["TABLE_CATALOG"], record["TABLE_SCHEMA"], record["TABLE_NAME"]
            )
            table = tables.get(key)
            if table is None:
                table = Table(
                    name=key,
                    description=record["TABLE_COMMENT"],
                    columns=[],
                    properties=TableProperties(
                        schema=record["TABLE_SCHEMA"],
                        catalog=record["TABLE_CATALOG"],
                        table=record["TABLE_NAME"],
                    ),
                    primaryKey="",
                )
                tables[key] = table

            # Complex types (array/map/struct) keep their raw lowercase form;
            # everything else goes through the scalar type mapping.
            lowered = record["DATA_TYPE"].lower()
            if lowered.startswith(("array", "map", "struct")):
                column_type = lowered
            else:
                column_type = self._transform_column_type(record["DATA_TYPE"])

            table.columns.append(
                Column(
                    name=record["COLUMN_NAME"],
                    type=column_type,
                    notNull=record["IS_NULLABLE"].lower() == "no",
                    description=record["COLUMN_COMMENT"],
                    properties=None,
                )
            )
        return list(tables.values())

    def get_constraints(self) -> list[Constraint]:
        """Return foreign-key constraints discovered via information_schema."""
        sql = """
        SELECT
            tc.table_catalog,
            tc.table_schema,
            tc.table_name,
            kcu.column_name,
            ccu.table_catalog AS foreign_table_catalog,
            ccu.table_schema AS foreign_table_schema,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.constraint_schema = kcu.constraint_schema
            AND tc.table_catalog = kcu.table_catalog
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.constraint_catalog = tc.constraint_catalog
            AND ccu.constraint_schema = tc.constraint_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
        """
        rows = self.connection.sql(sql).to_pandas().to_dict(orient="records")
        return [
            Constraint(
                constraintName=self._format_constraint_name(
                    row["table_name"],
                    row["column_name"],
                    row["foreign_table_name"],
                    row["foreign_column_name"],
                ),
                constraintTable=self._format_compact_table_name(
                    row["table_catalog"], row["table_schema"], row["table_name"]
                ),
                constraintColumn=row["column_name"],
                constraintedTable=self._format_compact_table_name(
                    row["foreign_table_catalog"],
                    row["foreign_table_schema"],
                    row["foreign_table_name"],
                ),
                constraintedColumn=row["foreign_column_name"],
                constraintType=ConstraintType.FOREIGN_KEY,
            )
            for row in rows
        ]

    def get_version(self) -> str:
        """Return the Databricks SQL warehouse version string."""
        frame = self.connection.sql(
            "SELECT current_version().dbsql_version"
        ).to_pandas()
        return frame.iloc[0, 0]

    def _format_constraint_name(
        self, table_name, column_name, foreign_table_name, foreign_column_name
    ):
        # Deterministic name built from the four participating identifiers.
        return f"{table_name}_{column_name}_{foreign_table_name}_{foreign_column_name}"

    def _format_compact_table_name(self, catalog: str, schema: str, table: str):
        # Fully-qualified "catalog.schema.table" identifier.
        return f"{catalog}.{schema}.{table}"

    def _transform_column_type(self, data_type: str) -> RustWrenEngineColumnType:
        """Map a Databricks scalar type name to the engine's column type."""
        lowered = data_type.lower()

        # Parameterized types only match by prefix (e.g. "decimal(10,2)").
        if lowered.startswith("decimal"):
            return RustWrenEngineColumnType.DECIMAL
        if lowered.startswith("geography"):
            return RustWrenEngineColumnType.GEOGRAPHY
        if lowered.startswith("geometry"):
            return RustWrenEngineColumnType.GEOMETRY

        try:
            return DATABRICKS_TYPE_MAPPING[lowered]
        except KeyError:
            # Unmapped types degrade to UNKNOWN but are logged for follow-up.
            logger.warning(f"Unknown Databricks data type: {data_type}")
            return RustWrenEngineColumnType.UNKNOWN
2 changes: 2 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.model.metadata.bigquery import BigQueryMetadata
from app.model.metadata.canner import CannerMetadata
from app.model.metadata.clickhouse import ClickHouseMetadata
from app.model.metadata.databricks import DatabricksMetadata
from app.model.metadata.metadata import Metadata
from app.model.metadata.mssql import MSSQLMetadata
from app.model.metadata.mysql import MySQLMetadata
Expand Down Expand Up @@ -35,6 +36,7 @@
DataSource.s3_file: S3FileMetadata,
DataSource.minio_file: MinioFileMetadata,
DataSource.gcs_file: GcsFileMetadata,
DataSource.databricks: DatabricksMetadata,
}


Expand Down
Loading