Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from pydantic import BaseModel, Field, SecretStr

from app.model.error import ErrorCode, WrenError

manifest_str_field = Field(alias="manifestStr", description="Base64 manifest")
connection_info_field = Field(alias="connectionInfo")

Expand All @@ -28,7 +30,7 @@ class QueryDTO(BaseModel):


class QueryBigQueryDTO(QueryDTO):
connection_info: BigQueryConnectionInfo = connection_info_field
connection_info: BigQueryConnectionUnion = connection_info_field


class QueryAthenaDTO(QueryDTO):
Expand Down Expand Up @@ -100,10 +102,6 @@ class ConnectionUrl(BaseConnectionInfo):


class BigQueryConnectionInfo(BaseConnectionInfo):
project_id: SecretStr = Field(description="GCP project id", examples=["my-project"])
dataset_id: SecretStr = Field(
description="BigQuery dataset id", examples=["my_dataset"]
)
credentials: SecretStr = Field(
description="Base64 encode `credentials.json`", examples=["eyJ..."]
)
Expand All @@ -112,6 +110,55 @@ class BigQueryConnectionInfo(BaseConnectionInfo):
default=None,
)

def get_billing_project_id(self) -> str | None:
raise WrenError(
ErrorCode.NOT_IMPLEMENTED,
"get_billing_project_id not implemented by base class",
)


class BigQueryDatasetConnectionInfo(BigQueryConnectionInfo):
bigquery_type: Literal["dataset"] = "dataset"
project_id: SecretStr = Field(description="GCP project id", examples=["my-project"])
dataset_id: SecretStr = Field(
description="BigQuery dataset id", examples=["my_dataset"]
)

def get_billing_project_id(self):
return self.project_id.get_secret_value()

def __hash__(self):
return hash((self.project_id, self.dataset_id, self.credentials))


class BigQueryProjectConnectionInfo(BigQueryConnectionInfo):
bigquery_type: Literal["project"] = "project"
region: SecretStr = Field(
description="the region of your BigQuery connection", examples=["US"]
)
billing_project_id: SecretStr = Field(
description="the billing project id of your BigQuery connection",
examples=["billing-project-1"],
)

def get_billing_project_id(self):
return self.billing_project_id.get_secret_value()

def __hash__(self):
return hash(
(
self.region,
self.billing_project_id,
self.credentials,
)
)


BigQueryConnectionUnion = Annotated[
Union[BigQueryDatasetConnectionInfo, BigQueryProjectConnectionInfo],
Field(discriminator="bigquery_type", default="dataset"),
]


class AthenaConnectionInfo(BaseConnectionInfo):
s3_staging_dir: SecretStr = Field(
Expand Down Expand Up @@ -570,7 +617,8 @@ class GcsFileConnectionInfo(BaseConnectionInfo):

ConnectionInfo = (
AthenaConnectionInfo
| BigQueryConnectionInfo
| BigQueryDatasetConnectionInfo
| BigQueryProjectConnectionInfo
| CannerConnectionInfo
| ClickHouseConnectionInfo
| ConnectionUrl
Expand Down
8 changes: 7 additions & 1 deletion ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,13 @@ def __init__(self, connection_info: ConnectionInfo):
"https://www.googleapis.com/auth/cloud-platform",
]
)
client = bigquery.Client(credentials=credentials)
client = bigquery.Client(
credentials=credentials,
project=connection_info.get_billing_project_id(),
)
job_config = bigquery.QueryJobConfig()
job_config.job_timeout_ms = self.connection_info.job_timeout_ms
client.default_query_job_config = job_config
self.connection = client

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
Expand Down
9 changes: 6 additions & 3 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from app.model import (
AthenaConnectionInfo,
BigQueryConnectionInfo,
BigQueryDatasetConnectionInfo,
BigQueryProjectConnectionInfo,
CannerConnectionInfo,
ClickHouseConnectionInfo,
ConnectionInfo,
Expand Down Expand Up @@ -152,7 +153,9 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
case DataSource.athena:
return AthenaConnectionInfo.model_validate(data)
case DataSource.bigquery:
return BigQueryConnectionInfo.model_validate(data)
if "bigquery_type" in data and data["bigquery_type"] == "project":
return BigQueryProjectConnectionInfo.model_validate(data)
return BigQueryDatasetConnectionInfo.model_validate(data)
case DataSource.canner:
return CannerConnectionInfo.model_validate(data)
case DataSource.clickhouse:
Expand Down Expand Up @@ -312,7 +315,7 @@ def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
return ibis.athena.connect(**kwargs)

@staticmethod
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
def get_bigquery_connection(info: BigQueryDatasetConnectionInfo) -> BaseBackend:
credits_json = loads(
base64.b64decode(info.credentials.get_secret_value()).decode("utf-8")
)
Expand Down
Loading
Loading