diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 96f89d6b9..e70539117 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, Field, SecretStr +from app.model.error import ErrorCode, WrenError + manifest_str_field = Field(alias="manifestStr", description="Base64 manifest") connection_info_field = Field(alias="connectionInfo") @@ -28,7 +30,7 @@ class QueryDTO(BaseModel): class QueryBigQueryDTO(QueryDTO): - connection_info: BigQueryConnectionInfo = connection_info_field + connection_info: BigQueryConnectionUnion = connection_info_field class QueryAthenaDTO(QueryDTO): @@ -100,10 +102,6 @@ class ConnectionUrl(BaseConnectionInfo): class BigQueryConnectionInfo(BaseConnectionInfo): - project_id: SecretStr = Field(description="GCP project id", examples=["my-project"]) - dataset_id: SecretStr = Field( - description="BigQuery dataset id", examples=["my_dataset"] - ) credentials: SecretStr = Field( description="Base64 encode `credentials.json`", examples=["eyJ..."] ) @@ -112,6 +110,55 @@ class BigQueryConnectionInfo(BaseConnectionInfo): default=None, ) + def get_billing_project_id(self) -> str | None: + raise WrenError( + ErrorCode.NOT_IMPLEMENTED, + "get_billing_project_id not implemented by base class", + ) + + +class BigQueryDatasetConnectionInfo(BigQueryConnectionInfo): + bigquery_type: Literal["dataset"] = "dataset" + project_id: SecretStr = Field(description="GCP project id", examples=["my-project"]) + dataset_id: SecretStr = Field( + description="BigQuery dataset id", examples=["my_dataset"] + ) + + def get_billing_project_id(self): + return self.project_id.get_secret_value() + + def __hash__(self): + return hash((self.project_id, self.dataset_id, self.credentials)) + + +class BigQueryProjectConnectionInfo(BigQueryConnectionInfo): + bigquery_type: Literal["project"] = "project" + region: SecretStr = Field( + description="the region of your BigQuery 
connection", examples=["US"] + ) + billing_project_id: SecretStr = Field( + description="the billing project id of your BigQuery connection", + examples=["billing-project-1"], + ) + + def get_billing_project_id(self): + return self.billing_project_id.get_secret_value() + + def __hash__(self): + return hash( + ( + self.region, + self.billing_project_id, + self.credentials, + ) + ) + + +BigQueryConnectionUnion = Annotated[ + Union[BigQueryDatasetConnectionInfo, BigQueryProjectConnectionInfo], + Field(discriminator="bigquery_type", default="dataset"), +] + class AthenaConnectionInfo(BaseConnectionInfo): s3_staging_dir: SecretStr = Field( @@ -570,7 +617,8 @@ class GcsFileConnectionInfo(BaseConnectionInfo): ConnectionInfo = ( AthenaConnectionInfo - | BigQueryConnectionInfo + | BigQueryDatasetConnectionInfo + | BigQueryProjectConnectionInfo | CannerConnectionInfo | ClickHouseConnectionInfo | ConnectionUrl diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index f556ee2c0..092454e01 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -521,7 +521,13 @@ def __init__(self, connection_info: ConnectionInfo): "https://www.googleapis.com/auth/cloud-platform", ] ) - client = bigquery.Client(credentials=credentials) + client = bigquery.Client( + credentials=credentials, + project=connection_info.get_billing_project_id(), + ) + job_config = bigquery.QueryJobConfig() + job_config.job_timeout_ms = self.connection_info.job_timeout_ms + client.default_query_job_config = job_config self.connection = client @tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index d43a1f13d..aba155b47 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -16,7 +16,8 @@ from app.model import ( AthenaConnectionInfo, - BigQueryConnectionInfo, + BigQueryDatasetConnectionInfo, + 
BigQueryProjectConnectionInfo, CannerConnectionInfo, ClickHouseConnectionInfo, ConnectionInfo, @@ -152,7 +153,9 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo: case DataSource.athena: return AthenaConnectionInfo.model_validate(data) case DataSource.bigquery: - return BigQueryConnectionInfo.model_validate(data) + if "bigquery_type" in data and data["bigquery_type"] == "project": + return BigQueryProjectConnectionInfo.model_validate(data) + return BigQueryDatasetConnectionInfo.model_validate(data) case DataSource.canner: return CannerConnectionInfo.model_validate(data) case DataSource.clickhouse: @@ -312,7 +315,7 @@ def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: return ibis.athena.connect(**kwargs) @staticmethod - def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend: + def get_bigquery_connection(info: BigQueryDatasetConnectionInfo) -> BaseBackend: credits_json = loads( base64.b64decode(info.credentials.get_secret_value()).decode("utf-8") ) diff --git a/ibis-server/app/model/metadata/bigquery.py b/ibis-server/app/model/metadata/bigquery.py index 4db2d4b8c..b203ba169 100644 --- a/ibis-server/app/model/metadata/bigquery.py +++ b/ibis-server/app/model/metadata/bigquery.py @@ -1,11 +1,20 @@ +from google.api_core.exceptions import Forbidden, NotFound from loguru import logger -from app.model import BigQueryConnectionInfo +from app.model import ( + BigQueryDatasetConnectionInfo, + BigQueryProjectConnectionInfo, +) from app.model.connector import BigQueryConnector +from app.model.error import ErrorCode, WrenError from app.model.metadata.dto import ( + BigQueryFilterInfo, + Catalog, Column, Constraint, ConstraintType, + FilterInfo, + ProjectDatasets, RustWrenEngineColumnType, Table, TableProperties, @@ -31,166 +40,441 @@ "timestamp": RustWrenEngineColumnType.TIMESTAMPTZ, } +BIGQUERY_PUBLIC_DATASET_PROJECT_ID = "bigquery-public-data" + class BigQueryMetadata(Metadata): - def __init__(self, connection_info: 
BigQueryConnectionInfo): + def __init__(self, connection_info: BigQueryDatasetConnectionInfo): super().__init__(connection_info) self.connection = BigQueryConnector(connection_info) - def get_table_list(self) -> list[Table]: - dataset_id = self.connection_info.dataset_id.get_secret_value() - - # filter out columns with GEOGRAPHY & RANGE types - sql = f""" - SELECT - c.table_catalog, - c.table_schema, - c.table_name, - c.column_name, - c.ordinal_position, - c.is_nullable, - c.is_generated, - c.generation_expression, - c.is_stored, - c.is_hidden, - c.is_updatable, - c.is_system_defined, - c.is_partitioning_column, - c.clustering_ordinal_position, - c.collation_name, - c.column_default, - c.rounding_mode, - cf.data_type, - cf.field_path, - cf.description AS column_description, - table_options.option_value AS table_description - FROM {dataset_id}.INFORMATION_SCHEMA.COLUMNS c - JOIN {dataset_id}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS cf - ON cf.table_name = c.table_name - AND cf.column_name = c.column_name - LEFT JOIN {dataset_id}.INFORMATION_SCHEMA.TABLE_OPTIONS table_options - ON c.table_name = table_options.table_name AND table_options.OPTION_NAME = 'description' - WHERE cf.data_type != 'GEOGRAPHY' - AND cf.data_type NOT LIKE 'RANGE%' - ORDER BY cf.field_path ASC - """ - response = self.connection.query(sql).to_pandas().to_dict(orient="records") - - def get_column(row) -> Column: - return Column( - # field_path supports both column & nested column - name=row["field_path"], - type=row["data_type"], - notNull=row["is_nullable"].lower() == "no", - description=row["column_description"], - properties={}, - nestedColumns=[] if has_nested_columns(row) else None, - ) + def get_table_list( + self, + filter_info: FilterInfo | None = None, + limit: int | None = None, + ) -> list[Table]: + project_datasets = self._get_project_datasets(filter_info) + self._validate_project_datasets_for_table_list(project_datasets) - def get_table(row) -> Table: - return Table( - name=table_name, - 
description=row["table_description"], - columns=[], - properties=TableProperties( - schema=row["table_schema"], - catalog=row["table_catalog"], - table=row["table_name"], - ), - primaryKey="", - ) + def build_list_table_sql(qualifier: str) -> str: + sql = f""" + SELECT + c.table_catalog, + c.table_schema, + c.table_name, + table_options.option_value AS table_description, + ARRAY_AGG(STRUCT( + c.column_name, + c.ordinal_position, + c.is_nullable, + c.is_generated, + c.generation_expression, + c.is_stored, + c.is_hidden, + c.is_updatable, + c.is_system_defined, + c.is_partitioning_column, + c.clustering_ordinal_position, + c.collation_name, + c.column_default, + c.rounding_mode, + cf.data_type, + cf.field_path, + cf.description AS column_description + ) ORDER BY cf.field_path ASC + ) AS columns + FROM {qualifier}.INFORMATION_SCHEMA.COLUMNS c + JOIN {qualifier}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS cf + ON cf.table_name = c.table_name + AND cf.column_name = c.column_name + LEFT JOIN {qualifier}.INFORMATION_SCHEMA.TABLE_OPTIONS table_options + ON c.table_name = table_options.table_name AND table_options.OPTION_NAME = 'description' + WHERE cf.data_type != 'GEOGRAPHY' + AND cf.data_type NOT LIKE 'RANGE%' + """ + if ( + filter_info + and hasattr(filter_info, "filter_pattern") + and filter_info.filter_pattern + ): + sql += f"\nAND REGEXP_CONTAINS(c.table_name, r'{filter_info.filter_pattern}') " - def is_root_column(row) -> bool: - return "." 
not in row["field_path"] - - def has_nested_columns(row) -> bool: - return "STRUCT" in row["data_type"] - - # eg: - # if I would like to find the parent_column of "messages.data.elements.aspectRatio" - # the output should be the column -> {name: "messages.data.elements", ...} - def find_parent_column(column_metadata, root_column) -> Column: - parent_column_names = column_metadata["field_path"].split(".")[1:-1] - if len(parent_column_names) == 0: - return root_column - col_ref = root_column - cur_column_name = root_column.name - for partial_column_name in parent_column_names: - cur_column_name = cur_column_name + "." + partial_column_name - col_ref = next( - filter( - lambda column: column.name == cur_column_name, - col_ref.nestedColumns, - ), - None, + if ( + isinstance(self.connection_info, BigQueryProjectConnectionInfo) + and project_dataset.dataset_ids + and len(project_dataset.dataset_ids) > 0 + ): + dataset_list = ", ".join( + [f"'{ds_id}'" for ds_id in project_dataset.dataset_ids] ) - if not col_ref: - return None - return col_ref + sql += f"\nAND c.table_schema in ({dataset_list})" + sql += "\nGROUP BY c.table_catalog, c.table_schema, c.table_name, table_description" + return sql + + project_sqls = [] + for project_dataset in project_datasets: + if project_dataset.project_id == BIGQUERY_PUBLIC_DATASET_PROJECT_ID: + for dataset in project_dataset.dataset_ids: + self._validate_dataset_region(project_dataset.project_id, dataset) + qualifier = f"`{project_dataset.project_id}.{dataset}`" + sql = build_list_table_sql(qualifier) + project_sqls.append(sql) + else: + qualifier = f"`{project_dataset.project_id}.{self._get_schema_qualifier(project_dataset.dataset_ids)}`" + sql = build_list_table_sql(qualifier) + project_sqls.append(sql) + + union_sql = "\nUNION ALL\n".join(project_sqls) + + if limit is not None: + union_sql += f"\n LIMIT {limit}" + + logger.debug(f"get table list SQL: {union_sql}") - unique_tables = {} + response = 
self.connection.query(union_sql).to_pylist() - for column_metadata in response: + table_list = [] + is_multiple_projects = len(project_datasets) > 1 + # if multiple datasets exist, we need to include schema in the table name (include catalog if multiple projects exist) + is_multiple_datasets = is_multiple_projects or ( + isinstance(self.connection_info, BigQueryProjectConnectionInfo) + and len(project_datasets[0].dataset_ids or []) > 1 + ) + + for table_metadata in response: # generate unique table name - table_name = column_metadata["table_name"] - # init table if not exists - if table_name not in unique_tables: - unique_tables[table_name] = get_table(column_metadata) + table_name = self._format_compact_table_name( + table_metadata["table_catalog"] if is_multiple_projects else None, + table_metadata["table_schema"] if is_multiple_datasets else None, + table_metadata["table_name"], + ) + table = self.get_table(table_metadata, table_name) - current_table = unique_tables[table_name] # if column is normal column, add to table - if is_root_column(column_metadata): - current_table.columns.append(get_column(column_metadata)) - # if column is nested column, find the parent nested column, and append to the nestedColumns of the parent column - else: - root_column_name = column_metadata["field_path"].split(".")[0] - root_column = next( - filter( - lambda column: column.name == root_column_name, - current_table.columns, - ), - None, + + for column_metadata in table_metadata["columns"]: + if self.is_root_column(column_metadata): + table.columns.append(self.get_column(column_metadata)) + # if column is nested column, find the parent nested column, and append to the nestedColumns of the parent column + else: + root_column_name = column_metadata["field_path"].split(".")[0] + root_column = next( + filter( + lambda column: column.name == root_column_name, + table.columns, + ), + None, + ) + if not root_column: + continue + parent_nested_column = self.find_parent_column( + 
column_metadata, root_column + ) + if parent_nested_column: + if not parent_nested_column.nestedColumns: + parent_nested_column.nestedColumns = [] + parent_nested_column.nestedColumns.append( + self.get_column(column_metadata) + ) + table_list.append(table) + + if len(table_list) == 0: + raise WrenError( + ErrorCode.NOT_FOUND, + "No tables found in the specified dataset. Please check your permissions and name of dataset and project.", + ) + + return table_list + + def _validate_dataset_region(self, project_id: str, dataset_id: str) -> bool: + dataset = self.connection.connection.get_dataset(f"{project_id}.{dataset_id}") + dataset_location = dataset.location.lower() + connection_region = self.connection_info.region.get_secret_value().lower() + if dataset_location != connection_region: + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + f"Dataset {project_id}.{dataset_id} is in region {dataset_location}, which does not match the connection region {connection_region}.", + ) + + def get_schema_list(self, filter_info=None, limit=None): + if isinstance(self.connection_info, BigQueryDatasetConnectionInfo): + return [ + Catalog( + name=self.connection_info.get_billing_project_id(), + schemas=[self.connection_info.dataset_id.get_secret_value()], ) - if not root_column: + ] + project_set = set() + project_set.add(self.connection_info.get_billing_project_id()) + project_datasets = self._get_project_datasets(filter_info) + used_bigquery_public_data = False + if project_datasets is not None: + for pd in project_datasets: + if pd.project_id == BIGQUERY_PUBLIC_DATASET_PROJECT_ID: + used_bigquery_public_data = True continue - parent_nested_column = find_parent_column(column_metadata, root_column) - if parent_nested_column: - parent_nested_column.nestedColumns.append( - get_column(column_metadata) - ) + project_set.add(pd.project_id) - return list(unique_tables.values()) - - def get_constraints(self) -> list[Constraint]: - dataset_id = 
self.connection_info.dataset_id.get_secret_value() - sql = f""" - SELECT - CONCAT(ccu.table_name, '_', ccu.column_name, '_', kcu.table_name, '_', kcu.column_name) as constraintName, - ccu.table_name as constraintTable, ccu.column_name constraintColumn, - kcu.table_name as constraintedTable, kcu.column_name as constraintedColumn, - FROM {dataset_id}.INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu - JOIN {dataset_id}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu - ON ccu.constraint_name = kcu.constraint_name - JOIN {dataset_id}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc - ON ccu.constraint_name = tc.constraint_name - WHERE tc.constraint_type = 'FOREIGN KEY' - """ - response = self.connection.query(sql).to_pandas().to_dict(orient="records") + project_sqls = [] + for project in list(project_set): + qualifier = ( + f"`{project}.region-{self.connection_info.region.get_secret_value()}`" + ) + # filter out columns with GEOGRAPHY & RANGE types + sql = f""" + SELECT + catalog_name, + schema_name + FROM {qualifier}.INFORMATION_SCHEMA.SCHEMATA + """ + project_sqls.append(sql) - constraints = [] + union_sql = "\nUNION ALL\n".join(project_sqls) + + if limit is not None: + union_sql += f"\n LIMIT {limit}" + + grouping_sql = f""" + SELECT catalog_name, array_agg(schema_name) AS schema_names + FROM ( + {union_sql} + ) + GROUP BY catalog_name + """ + logger.debug(f"get schema list SQL: {grouping_sql}") + + response = self.connection.query(grouping_sql).to_pylist() + + if len(response) == 0: + raise WrenError( + ErrorCode.NOT_FOUND, + "No tables found in the specified dataset. 
Please check your permissions and name of dataset and project.", + ) + project_list = [] for row in response: - constraints.append( - Constraint( - constraintName=row["constraintName"], - constraintTable=row["constraintTable"], - constraintColumn=row["constraintColumn"], - constraintedTable=row["constraintedTable"], - constraintedColumn=row["constraintedColumn"], - constraintType=ConstraintType.FOREIGN_KEY, + project_list.append( + Catalog(name=row["catalog_name"], schemas=row["schema_names"]) + ) + + if used_bigquery_public_data: + public_data_set = self.connection.connection.list_datasets( + project=BIGQUERY_PUBLIC_DATASET_PROJECT_ID + ) + dataset_list = [dataset.dataset_id for dataset in public_data_set] + project_list.append( + Catalog( + name=BIGQUERY_PUBLIC_DATASET_PROJECT_ID, + schemas=dataset_list, + ) + ) + + return project_list + + def get_column(self, row) -> Column: + return Column( + # field_path supports both column & nested column + name=row["field_path"], + type=row["data_type"], + notNull=row["is_nullable"].lower() == "no", + description=row["column_description"], + properties={ + "column_order": row["ordinal_position"], + }, + nestedColumns=[] if self.has_nested_columns(row) else None, + ) + + def get_table(self, row, table_name: str) -> Table: + return Table( + name=table_name, + description=row["table_description"], + columns=[], + properties=TableProperties( + schema=row["table_schema"], + catalog=row["table_catalog"], + table=row["table_name"], + ), + primaryKey="", + ) + + def is_root_column(self, row) -> bool: + return "." 
not in row["field_path"] + + def has_nested_columns(self, row) -> bool: + return "STRUCT" in row["data_type"] + + def _format_compact_table_name( + self, + catalog: str | None, + schema: str | None, + table: str, + delimiter: str = ".", + ) -> str: + if schema is None or schema == "": + return f"{table}" + if catalog is None or catalog == "": + return f"{schema}{delimiter}{table}" + return f"{catalog}{delimiter}{schema}{delimiter}{table}" + + # eg: + # if I would like to find the parent_column of "messages.data.elements.aspectRatio" + # the output should be the column -> {name: "messages.data.elements", ...} + def find_parent_column(self, column_metadata, root_column) -> Column: + parent_column_names = column_metadata["field_path"].split(".")[1:-1] + if len(parent_column_names) == 0: + return root_column + col_ref = root_column + cur_column_name = root_column.name + for partial_column_name in parent_column_names: + cur_column_name = cur_column_name + "." + partial_column_name + col_ref = next( + filter( + lambda column: column.name == cur_column_name, + col_ref.nestedColumns, + ), + None, + ) + if not col_ref: + return None + return col_ref + + def get_constraints( + self, filter_info: FilterInfo | None = None + ) -> list[Constraint]: + project_datasets = self._get_project_datasets(filter_info) + constraints = [] + + for project_dataset in project_datasets: + project_id = project_dataset.project_id + schema_qualifier = self._get_schema_qualifier(project_dataset.dataset_ids) + qualifier = f"`{project_id}.{schema_qualifier}`" + sql = f""" + SELECT + tc.table_catalog as start_catalog, + tc.table_schema as start_schema, + tc.table_name as start_table, + kcu.column_name as start_column, + ccu.table_catalog as end_catalog, + ccu.table_schema as end_schema, + ccu.table_name as end_table, + ccu.column_name as end_column + FROM {qualifier}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN {qualifier}.INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu + ON tc.constraint_catalog = 
ccu.constraint_catalog + AND tc.constraint_schema = ccu.constraint_schema + AND tc.constraint_name = ccu.constraint_name + JOIN {qualifier}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.constraint_catalog = kcu.constraint_catalog + AND tc.constraint_schema = kcu.constraint_schema + AND tc.constraint_name = kcu.constraint_name + WHERE tc.constraint_type = 'FOREIGN KEY' + """ + + logger.debug(f"get constraints SQL: {sql}") + + try: + response = self.connection.query(sql).to_pylist() + except (Forbidden, NotFound): + logger.warning( + f"Can't access dataset {schema_qualifier} in project {project_id}" ) + continue + + is_multiple_projects = len(project_datasets) > 1 + # if multiple datasets exist, we need to include schema in the table name (include catalog if multiple projects exist) + is_multiple_datasets = is_multiple_projects or ( + isinstance(self.connection_info, BigQueryProjectConnectionInfo) + and len(project_dataset.dataset_ids or []) > 1 ) + + for row in response: + start_catalog = row["start_catalog"] if is_multiple_projects else None + start_schema = row["start_schema"] if is_multiple_datasets else None + start_table = self._format_compact_table_name( + start_catalog, start_schema, row["start_table"] + ) + end_catalog = row["end_catalog"] if is_multiple_projects else None + end_schema = row["end_schema"] if is_multiple_datasets else None + end_table = self._format_compact_table_name( + end_catalog, end_schema, row["end_table"] + ) + + start_table_underline = self._format_compact_table_name( + start_catalog, start_schema, row["start_table"], delimiter="_" + ) + end_table_underline = self._format_compact_table_name( + end_catalog, end_schema, row["end_table"], delimiter="_" + ) + constraints.append( + Constraint( + constraintName=f"{start_table_underline}_{row['start_column']}_{end_table_underline}_{row['end_column']}", + constraintTable=start_table, + constraintColumn=row["start_column"], + constraintedTable=end_table, + 
constraintedColumn=row["end_column"], + constraintType=ConstraintType.FOREIGN_KEY, + ) + ) return constraints + def _get_project_datasets( + self, filter: FilterInfo | None + ) -> list["ProjectDatasets"]: + if isinstance(self.connection_info, BigQueryDatasetConnectionInfo): + return [ + ProjectDatasets( + projectId=self.connection_info.get_billing_project_id(), + datasetIds=[self.connection_info.dataset_id.get_secret_value()], + ) + ] + elif isinstance(self.connection_info, BigQueryProjectConnectionInfo): + if filter is not None and isinstance(filter, BigQueryFilterInfo): + return filter.projects + return [ + ProjectDatasets( + projectId=self.connection_info.get_billing_project_id(), + datasetIds=None, + ) + ] + raise WrenError( + ErrorCode.GENERIC_INTERNAL_ERROR, "Invalid connection info type" + ) + + def _get_schema_qualifier(self, dataset_ids): + if isinstance(self.connection_info, BigQueryDatasetConnectionInfo): + return dataset_ids[0] + elif isinstance(self.connection_info, BigQueryProjectConnectionInfo): + return f"region-{self.connection_info.region.get_secret_value()}" + raise WrenError( + ErrorCode.GENERIC_INTERNAL_ERROR, "Invalid connection info type" + ) + + def _validate_project_datasets_for_table_list( + self, project_datasets: list["ProjectDatasets"] + ): + if len(project_datasets) == 0: + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "At least one project and dataset must be specified", + ) + for project_dataset in project_datasets: + if project_dataset.project_id is None or project_dataset.project_id == "": + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, "project id should not be empty" + ) + if ( + project_dataset.dataset_ids is None + or len(project_dataset.dataset_ids) == 0 + ): + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, "dataset ids should not be empty" + ) + for dataset_id in project_dataset.dataset_ids: + if dataset_id is None or dataset_id == "": + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "dataset 
id should not be empty", + ) + def get_version(self) -> str: return "Follow BigQuery release version" diff --git a/ibis-server/app/model/metadata/dto.py b/ibis-server/app/model/metadata/dto.py index 70524121a..685313617 100644 --- a/ibis-server/app/model/metadata/dto.py +++ b/ibis-server/app/model/metadata/dto.py @@ -4,10 +4,36 @@ from pydantic import BaseModel, Field from app.model import ConnectionInfo +from app.model.data_source import DataSource + + +class V2MetadataDTO(BaseModel): + connection_info: dict[str, Any] | ConnectionInfo = Field(alias="connectionInfo") + + +class FilterInfo(BaseModel): + pass class MetadataDTO(BaseModel): connection_info: dict[str, Any] | ConnectionInfo = Field(alias="connectionInfo") + table_limit: int | None = Field(alias="limit", default=None) + filter_info: dict[str, Any] | None = Field(alias="filterInfo", default=None) + + +class BigQueryFilterInfo(FilterInfo): + projects: list["ProjectDatasets"] | None = None + + +class ProjectDatasets(BaseModel): + project_id: str = Field(alias="projectId") + dataset_ids: list[str] | None = Field(alias="datasetIds", default=None) + + +def get_filter_info(data_source: DataSource, info: dict[str, Any]) -> FilterInfo | None: + if data_source == DataSource.bigquery: + return BigQueryFilterInfo(**info) + return None class RustWrenEngineColumnType(Enum): @@ -85,6 +111,11 @@ class Table(BaseModel): primaryKey: str | None = None +class Catalog(BaseModel): + name: str + schemas: list[str] + + class ConstraintType(Enum): PRIMARY_KEY = "PRIMARY KEY" FOREIGN_KEY = "FOREIGN KEY" diff --git a/ibis-server/app/model/metadata/metadata.py b/ibis-server/app/model/metadata/metadata.py index 244ba60cd..cfdebe081 100644 --- a/ibis-server/app/model/metadata/metadata.py +++ b/ibis-server/app/model/metadata/metadata.py @@ -1,21 +1,24 @@ -from abc import ABC, abstractmethod - from app.model import ConnectionInfo -from app.model.metadata.dto import Constraint, Table +from app.model.error import ErrorCode, WrenError 
+from app.model.metadata.dto import Catalog, Constraint, FilterInfo, Table -class Metadata(ABC): +class Metadata: def __init__(self, connection_info: ConnectionInfo): self.connection_info = connection_info - @abstractmethod - def get_table_list(self) -> list[Table]: - pass + def get_table_list( + self, filter_info: FilterInfo | None = None, limit: int | None = None + ) -> list[Table]: + raise WrenError(ErrorCode.NOT_IMPLEMENTED, "get_table_list not implemented") - @abstractmethod def get_constraints(self) -> list[Constraint]: - pass + raise WrenError(ErrorCode.NOT_IMPLEMENTED, "get_constraints not implemented") - @abstractmethod def get_version(self) -> str: - pass + raise WrenError(ErrorCode.NOT_IMPLEMENTED, "get_version not implemented") + + def get_schema_list( + self, filter_info: FilterInfo | None = None, limit: int | None = None + ) -> list[Catalog]: + raise WrenError(ErrorCode.NOT_IMPLEMENTED, "get_schema_list not implemented") diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index e6a28be16..2b1d4dc73 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -18,6 +18,7 @@ from app.mdl.rewriter import Rewriter from app.mdl.substitute import ModelSubstitute from app.model import ( + BigQueryProjectConnectionInfo, DryPlanDTO, QueryDTO, TranspileDTO, @@ -25,7 +26,12 @@ ) from app.model.connector import Connector from app.model.data_source import DataSource -from app.model.metadata.dto import Constraint, MetadataDTO, Table +from app.model.error import ErrorCode, WrenError +from app.model.metadata.dto import ( + Constraint, + Table, + V2MetadataDTO, +) from app.model.metadata.factory import MetadataFactory from app.model.validator import Validator from app.query_cache import QueryCacheManager @@ -258,7 +264,7 @@ async def validate( ) async def get_table_list( data_source: DataSource, - dto: MetadataDTO, + dto: V2MetadataDTO, headers: Annotated[Headers, 
Depends(get_wren_headers)] = None, ) -> list[Table]: span_name = f"v2_metadata_tables_{data_source}" @@ -269,8 +275,14 @@ async def get_table_list( connection_info = data_source.get_connection_info( dto.connection_info, dict(headers) ) - metadata = MetadataFactory.get_metadata(data_source, connection_info) - return await execute_get_table_list_with_timeout(metadata) + if isinstance(connection_info, BigQueryProjectConnectionInfo): + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "BigQuery project-level connection info is only supported by v3 API for metadata table list retrieval.", + ) + else: + metadata = MetadataFactory.get_metadata(data_source, connection_info) + return await execute_get_table_list_with_timeout(metadata) @router.post( @@ -281,7 +293,7 @@ async def get_table_list( ) async def get_constraints( data_source: DataSource, - dto: MetadataDTO, + dto: V2MetadataDTO, headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> list[Constraint]: span_name = f"v2_metadata_constraints_{data_source}" @@ -292,6 +304,11 @@ async def get_constraints( connection_info = data_source.get_connection_info( dto.connection_info, dict(headers) ) + if isinstance(connection_info, BigQueryProjectConnectionInfo): + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "BigQuery project-level connection info is only supported by v3 API for metadata constraints retrieval.", + ) metadata = MetadataFactory.get_metadata(data_source, connection_info) return await execute_get_constraints_with_timeout(metadata) @@ -303,7 +320,7 @@ async def get_constraints( ) async def get_db_version( data_source: DataSource, - dto: MetadataDTO, + dto: V2MetadataDTO, headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> str: connection_info = data_source.get_connection_info( diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index d0c2b454c..e6c711ca3 100644 --- a/ibis-server/app/routers/v3/connector.py +++ 
b/ibis-server/app/routers/v3/connector.py @@ -32,6 +32,8 @@ from app.model.connector import Connector from app.model.data_source import DataSource from app.model.error import DatabaseTimeoutError +from app.model.metadata.dto import Catalog, MetadataDTO, Table, get_filter_info +from app.model.metadata.factory import MetadataFactory from app.model.validator import Validator from app.query_cache import QueryCacheManager from app.routers import v2 @@ -40,6 +42,8 @@ append_fallback_context, build_context, execute_dry_run_with_timeout, + execute_get_schema_list_with_timeout, + execute_get_table_list_with_timeout, execute_query_with_timeout, execute_validate_with_timeout, pushdown_limit, @@ -547,3 +551,55 @@ async def get_sql_knowledge( "text_to_sql_rule": knowledge_manager.get_text_to_sql_rule(), "instructions": knowledge_manager.get_sql_instructions(), } + + +@router.post( + "/{data_source}/metadata/tables", + description="get the table metadata of the specified data source", +) +async def get_table_list( + data_source: DataSource, + dto: MetadataDTO, + headers: Annotated[Headers, Depends(get_wren_headers)], +) -> list[Table]: + span_name = f"v3_get_table_list_{data_source}" + with tracer.start_as_current_span( + name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) + ) as span: + set_attribute(headers, span) + connection_info = data_source.get_connection_info( + dto.connection_info, dict(headers) + ) + metadata = MetadataFactory.get_metadata(data_source, connection_info) + filter_info = get_filter_info(data_source, dto.filter_info or {}) + return await execute_get_table_list_with_timeout( + metadata, + filter_info, + dto.table_limit, + ) + + +@router.post( + "/{data_source}/metadata/schemas", + description="get the schema metadata of the specified data source", +) +async def get_schema_list( + data_source: DataSource, + dto: MetadataDTO, + headers: Annotated[Headers, Depends(get_wren_headers)], +) -> list[Catalog]: + span_name = 
f"v3_get_schema_list_{data_source}" + with tracer.start_as_current_span( + name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) + ) as span: + set_attribute(headers, span) + connection_info = data_source.get_connection_info( + dto.connection_info, dict(headers) + ) + metadata = MetadataFactory.get_metadata(data_source, connection_info) + filter_info = get_filter_info(data_source, dto.filter_info or {}) + return await execute_get_schema_list_with_timeout( + metadata, + filter_info, + dto.table_limit, + ) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index d134e874a..631eea730 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -42,6 +42,8 @@ class ClickHouseDbError(Exception): ) from app.model.data_source import DataSource from app.model.error import DatabaseTimeoutError +from app.model.metadata.bigquery import BigQueryMetadata +from app.model.metadata.dto import FilterInfo from app.model.metadata.metadata import Metadata tracer = trace.get_tracer(__name__) @@ -350,14 +352,52 @@ async def execute_dry_run_with_timeout(connector, sql: str): async def execute_get_table_list_with_timeout( metadata: Metadata, + filter_info: FilterInfo | None = None, + limit: int | None = None, ): """Get the list of tables with a timeout control.""" + if isinstance(metadata, BigQueryMetadata): + return await execute_with_timeout( + asyncio.create_task( + asyncio.to_thread( + metadata.get_table_list, + filter_info, + limit, + ) + ), + "Get Table List", + ) + return await execute_with_timeout( + asyncio.to_thread(metadata.get_table_list), + "Get Table List", + ) + +async def execute_get_schema_list_with_timeout( + metadata: Metadata, + filter_info: FilterInfo | None = None, + limit: int | None = None, +): + """Get the list of schemas with a timeout control.""" + if isinstance(metadata, BigQueryMetadata): + return await execute_with_timeout( + asyncio.create_task( + asyncio.to_thread( + metadata.get_schema_list, + filter_info, + limit, + ) + ), 
+ "Get Schema List", + ) + + return await execute_with_timeout( + asyncio.to_thread(metadata.get_schema_list), + "Get Schema List", + ) + + async def execute_get_constraints_with_timeout( metadata: Metadata, ): diff --git a/ibis-server/tests/routers/v2/connector/test_bigquery.py b/ibis-server/tests/routers/v2/connector/test_bigquery.py index 223c7629b..d4c083397 100644 --- a/ibis-server/tests/routers/v2/connector/test_bigquery.py +++ b/ibis-server/tests/routers/v2/connector/test_bigquery.py @@ -392,6 +392,24 @@ async def test_metadata_list_constraints(client): assert response.status_code == 200 +async def test_metadata_list_unsupported_project_connection(client): + multi_dataset_connection_info = { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": os.getenv("TEST_BIG_QUERY_REGION", "asia-east1"), + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + } + response = await client.post( + url=f"{base_url}/metadata/tables", + json={"connectionInfo": multi_dataset_connection_info}, + ) + assert response.status_code == 422 + assert ( + "BigQuery project-level connection info is only supported by v3 API for metadata table list retrieval." 
+ in response.text + ) + + async def test_metadata_db_version(client): response = await client.post( url=f"{base_url}/metadata/version", diff --git a/ibis-server/tests/routers/v3/connector/bigquery/conftest.py b/ibis-server/tests/routers/v3/connector/bigquery/conftest.py index 1fe6c7232..f47578771 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/conftest.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/conftest.py @@ -30,6 +30,16 @@ def connection_info(): } +@pytest.fixture(scope="session") +def project_connection_info(): + return { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": os.getenv("TEST_BIG_QUERY_REGION", "asia-east1"), + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + } + + @pytest.fixture(autouse=True) def set_remote_function_list_path(): config = get_config() diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py index 817c35917..68dffe9a0 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py @@ -1,10 +1,13 @@ import base64 +import os import time +from re import split import orjson import pytest from app.dependencies import X_WREN_FALLBACK_DISABLE, X_WREN_VARIABLE_PREFIX +from app.model.metadata.bigquery import BIGQUERY_PUBLIC_DATASET_PROJECT_ID from tests.routers.v3.connector.bigquery.conftest import base_url manifest = { @@ -162,6 +165,94 @@ async def test_query(client, manifest_str, connection_info): "dst_utc_minus_4": "timestamp[us, tz=UTC]", } + multi_dataset_connection_info = { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": os.getenv("TEST_BIG_QUERY_REGION", "asia-east1"), + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + } + + response = await client.post( + url=f"{base_url}/query", + json={ + 
"connectionInfo": multi_dataset_connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + +async def test_query_cross_project_dataset(client): + manifest = { + "catalog": "wren", + "schema": "public", + "dataSource": "bigquery", + "models": [ + { + "name": "orders", + "tableReference": { + "catalog": "wrenai", + "schema": "tpch_tiny_us", + "table": "orders", + }, + "columns": [ + { + "name": "o_orderkey", + "type": "integer", + }, + { + "name": "o_custkey", + "type": "integer", + }, + ], + }, + { + "name": "311_service_requests", + "tableReference": { + "catalog": "bigquery-public-data", + "schema": "austin_311", + "table": "311_service_requests", + }, + "columns": [ + { + "name": "city", + "type": "string", + }, + { + "name": "unique_key", + "type": "string", + }, + ], + }, + ], + } + + manifest_str = base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + connection_info = { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": "US", + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + } + + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT o.o_orderkey, s.city FROM orders o CROSS JOIN "311_service_requests" s LIMIT 2', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 2 + async def test_query_with_cache(client, manifest_str, connection_info): # add random timestamp to the query to ensure cache is not hit @@ -598,3 +689,243 @@ async def test_cache_ignores_irrelevant_headers(client, manifest_str, connection assert ( 
response2.headers["X-Cache-Hit"] == "true" ) # Should hit cache despite different irrelevant headers + + +async def test_metadata_list_schemas(client, project_connection_info): + response = await client.post( + url=f"{base_url}/metadata/schemas", + json={ + "connectionInfo": project_connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert isinstance(result, list) + # one project + assert len(result) == 1 + # multiple datasets + assert len(result[0]["schemas"]) > 1 + assert any(schema == "tpch_tiny" for schema in result[0]["schemas"]) + assert any(schema == "tpch_sf1" for schema in result[0]["schemas"]) + assert len(result[0]["schemas"]) > 2 + + response = await client.post( + url=f"{base_url}/metadata/schemas", + json={ + "connectionInfo": { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": "US", + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + }, + "filterInfo": { + "projects": [{"projectId": BIGQUERY_PUBLIC_DATASET_PROJECT_ID}] + }, + }, + ) + + result = response.json() + assert len(result) == 2 + # multiple datasets + public_project = next( + project + for project in result + if project["name"] == BIGQUERY_PUBLIC_DATASET_PROJECT_ID + ) + assert any(schema == "austin_311" for schema in public_project["schemas"]) + + +async def test_metadata_list_tables(client, project_connection_info): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "wrenai", + "datasetIds": ["tpch_tiny"], + } + ] + }, + }, + ) + + assert response.status_code == 200 + assert len(response.json()) == 8 + table_name = response.json()[0]["name"] + assert len(split(r"\.", table_name)) == 1 # no catalog and schema in table name + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + 
"filterInfo": { + "projects": [ + { + "projectId": "wrenai", + "datasetIds": ["tpch_tiny", "tpch_sf1"], + } + ] + }, + }, + ) + + assert response.status_code == 200 + assert len(response.json()) == 16 + table_name = response.json()[0]["name"] + assert len(split(r"\.", table_name)) == 2 # no catalog in table name + + +async def test_metadata_list_public_dataset_tables(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": "asia-east1", + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + }, + "filterInfo": { + "projects": [ + { + "projectId": BIGQUERY_PUBLIC_DATASET_PROJECT_ID, + "datasetIds": ["austin_311"], + } + ] + }, + }, + ) + + assert response.status_code == 422 + assert ( + f"Dataset {BIGQUERY_PUBLIC_DATASET_PROJECT_ID}.austin_311 is in region us, which does not match the connection region asia-east1." + in response.text + ) + + us_connection_info = { + "bigquery_type": "project", + "billing_project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), + "region": "US", + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), + } + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": us_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": BIGQUERY_PUBLIC_DATASET_PROJECT_ID, + "datasetIds": ["austin_311"], + } + ] + }, + }, + ) + + assert response.status_code == 200 + assert len(response.json()) > 0 + + +async def test_metadata_list_tables_missing_field(client, project_connection_info): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": {"projects": []}, + }, + ) + + assert response.status_code == 422 + assert "At least one project and dataset must be specified" in response.text + + response = await client.post( + 
url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "", + } + ], + }, + }, + ) + + assert response.status_code == 422 + assert "project id should not be empty" in response.text + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "", + } + ], + }, + }, + ) + + assert response.status_code == 422 + assert "project id should not be empty" in response.text + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "wrenai", + "datasetIds": [], + } + ], + }, + }, + ) + + assert response.status_code == 422 + assert "dataset ids should not be empty" in response.text + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "wrenai", + "datasetIds": [""], + } + ], + }, + }, + ) + + assert response.status_code == 422 + assert "dataset id should not be empty" in response.text + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": project_connection_info, + "filterInfo": { + "projects": [ + { + "projectId": "wrenai", + "datasetIds": ["tpch_sf1", ""], + } + ], + }, + }, + ) + + assert response.status_code == 422 + assert "dataset id should not be empty" in response.text diff --git a/ibis-server/tools/query_local_run.py b/ibis-server/tools/query_local_run.py index dc69981c1..7b80516c4 100644 --- a/ibis-server/tools/query_local_run.py +++ b/ibis-server/tools/query_local_run.py @@ -15,7 +15,8 @@ import json import os from app.custom_sqlglot.dialects.wren import Wren -from app.model import MSSqlConnectionInfo, MySqlConnectionInfo, OracleConnectionInfo, PostgresConnectionInfo, 
SnowflakeConnectionInfo +from app.model import BigQueryDatasetConnectionInfo, MSSqlConnectionInfo, MySqlConnectionInfo, OracleConnectionInfo, PostgresConnectionInfo, SnowflakeConnectionInfo +from app.model.connector import BigQueryConnector from app.util import to_json import sqlglot import sys @@ -23,7 +24,6 @@ from dotenv import load_dotenv from wren_core import SessionContext -from app.model.data_source import BigQueryConnectionInfo from app.model.data_source import DataSourceExtension import wren_core @@ -91,8 +91,8 @@ print("#") if data_source == "bigquery": - connection_info = BigQueryConnectionInfo.model_validate_json(json.dumps(connection_info)) - connection = DataSourceExtension.get_bigquery_connection(connection_info) + connection_info = BigQueryDatasetConnectionInfo.model_validate_json(json.dumps(connection_info)) + connection = BigQueryConnector(connection_info) elif data_source == "mysql": connection_info = MySqlConnectionInfo.model_validate_json(json.dumps(connection_info)) connection = DataSourceExtension.get_mysql_connection(connection_info)