Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib
import os
import time
from abc import ABC, abstractmethod
from contextlib import closing, suppress
from decimal import Decimal as PyDecimal
from functools import cache
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
elif data_source == DataSource.databricks:
self._connector = DatabricksConnector(connection_info)
else:
self._connector = SimpleConnector(data_source, connection_info)
self._connector = IbisConnector(data_source, connection_info)

def query(self, sql: str, limit: int | None = None) -> pa.Table:
try:
Expand Down Expand Up @@ -179,7 +180,21 @@ def close(self) -> None:
)


class SimpleConnector:
class ConnectorABC(ABC):
@abstractmethod
def query(self, sql: str, limit: int | None = None) -> pa.Table:
pass

@abstractmethod
def dry_run(self, sql: str) -> None:
pass

@abstractmethod
def close(self) -> None:
pass


class IbisConnector(ConnectorABC):
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
self.data_source = data_source
self.connection = self.data_source.get_connection(connection_info)
Expand Down Expand Up @@ -258,7 +273,7 @@ def close(self) -> None:
self.connection = None


class PostgresConnector(SimpleConnector):
class PostgresConnector(IbisConnector):
def __init__(self, connection_info):
super().__init__(DataSource.postgres, connection_info)

Expand Down Expand Up @@ -302,7 +317,7 @@ def close(self) -> None:
self.connection = None


class MSSqlConnector(SimpleConnector):
class MSSqlConnector(IbisConnector):
def __init__(self, connection_info: ConnectionInfo):
super().__init__(DataSource.mssql, connection_info)

Expand Down Expand Up @@ -369,7 +384,7 @@ def _describe_sql_for_error_message(self, sql: str) -> str:
return rows[0][0]


class CannerConnector:
class CannerConnector(IbisConnector):
def __init__(self, connection_info: ConnectionInfo):
self.connection = DataSource.canner.get_connection(connection_info)

Expand Down Expand Up @@ -432,13 +447,10 @@ def _to_ibis_type(type_name: str) -> dt.DataType:
return postgres_compiler.type_mapper.from_string(type_name)


class BigQueryConnector(SimpleConnector):
class BigQueryConnector(ConnectorABC):
def __init__(self, connection_info: ConnectionInfo):
super().__init__(DataSource.bigquery, connection_info)
self.data_source = DataSource.bigquery
self.connection_info = connection_info

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int | None = None) -> pa.Table:
credits_json = loads(
base64.b64decode(
self.connection_info.credentials.get_secret_value()
Expand All @@ -454,10 +466,28 @@ def query(self, sql: str, limit: int | None = None) -> pa.Table:
]
)
client = bigquery.Client(credentials=credentials)
return client.query(sql).result(max_results=limit).to_arrow()
self.connection = client

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int | None = None) -> pa.Table:
return self.connection.query(sql).result(max_results=limit).to_arrow()

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> None:
self.connection.query(
sql, job_config=bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
)

@tracer.start_as_current_span("connector_close", kind=trace.SpanKind.CLIENT)
def close(self) -> None:
"""Close the BigQuery connection."""
try:
self.connection.close()
except Exception as e:
logger.warning(f"Error closing BigQuery connection: {e}")


class DuckDBConnector:
class DuckDBConnector(ConnectorABC):
def __init__(self, connection_info: ConnectionInfo):
import duckdb # noqa: PLC0415

Expand Down Expand Up @@ -538,7 +568,7 @@ def close(self) -> None:
logger.warning(f"Error closing DuckDB connection: {e}")


class RedshiftConnector:
class RedshiftConnector(ConnectorABC):
def __init__(self, connection_info: RedshiftConnectionUnion):
import redshift_connector # noqa: PLC0415

Expand Down Expand Up @@ -594,7 +624,7 @@ def close(self) -> None:
logger.warning(f"Error closing Redshift connection: {e}")


class DatabricksConnector(SimpleConnector):
class DatabricksConnector(ConnectorABC):
def __init__(self, connection_info: DatabricksConnectionUnion):
if isinstance(connection_info, DatabricksTokenConnectionInfo):
self.connection = dbsql.connect(
Expand Down
Loading