Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class BigQueryConnectionInfo(BaseConnectionInfo):
credentials: SecretStr = Field(
description="Base64 encode `credentials.json`", examples=["eyJ..."]
)
job_timeout_ms: int | None = Field(
description="Job timeout in milliseconds. If the job is not complete within the specified time, it will be cancelled.",
default=None,
)


class AthenaConnectionInfo(BaseConnectionInfo):
Expand Down
3 changes: 3 additions & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def close(self) -> None:
elif hasattr(self.connection, "close"):
# Try to close the connection directly if it has a close method
self.connection.close()
elif hasattr(self.connection, "disconnect"):
# Some backends use disconnect instead of close
self.connection.disconnect()
else:
logger.warning(
f"Closing connection for {self.data_source.value} is not implemented."
Expand Down
16 changes: 12 additions & 4 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from urllib.parse import unquote_plus

import ibis
from google.cloud import bigquery
from google.oauth2 import service_account
from ibis import BaseBackend

Expand Down Expand Up @@ -124,6 +125,10 @@ def get_connection_info(
f"{session_timeout}s"
)
info.kwargs["session_properties"] = session_properties
case DataSource.bigquery:
session_timeout = headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180)
if not hasattr(info, "job_timeout_ms") or info.job_timeout_ms is None:
info.job_timeout_ms = int(session_timeout) * 1000
return info

def _build_connection_info(self, data: dict) -> ConnectionInfo:
Expand Down Expand Up @@ -265,11 +270,14 @@ def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
"https://www.googleapis.com/auth/cloud-platform",
]
)
return ibis.bigquery.connect(
project_id=info.project_id.get_secret_value(),
dataset_id=info.dataset_id.get_secret_value(),
credentials=credentials,
bq_client = bigquery.Client(
project=info.project_id.get_secret_value(), credentials=credentials
)
job_config = bigquery.QueryJobConfig()
job_config.job_timeout_ms = info.job_timeout_ms
bq_client.default_query_job_config = job_config
backend = ibis.bigquery.connect(client=bq_client, credentials=credentials)
return backend

@staticmethod
def get_canner_connection(info: CannerConnectionInfo) -> BaseBackend:
Expand Down