diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 7707a0650..b4dcee984 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -3,6 +3,7 @@ import os import time from contextlib import closing, suppress +from decimal import Decimal as PyDecimal from functools import cache from json import loads from typing import Any @@ -19,6 +20,9 @@ from google.oauth2 import service_account from ibis import BaseBackend from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler +from ibis.expr.datatypes import Decimal +from ibis.expr.datatypes.core import UUID +from ibis.expr.types import Table from loguru import logger from opentelemetry import trace @@ -35,7 +39,6 @@ ) from app.model.data_source import DataSource from app.model.utils import init_duckdb_gcs, init_duckdb_minio, init_duckdb_s3 -from app.util import round_decimal_columns # Override datatypes of ibis importlib.import_module("app.custom_ibis.backends.sql.datatypes") @@ -115,9 +118,40 @@ def query(self, sql: str, limit: int | None = None) -> pa.Table: ibis_table = self.connection.sql(sql) if limit is not None: ibis_table = ibis_table.limit(limit) - ibis_table = round_decimal_columns(ibis_table) + ibis_table = self._handle_pyarrow_unsupported_type(ibis_table) return ibis_table.to_pyarrow() + def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table: + result_table = ibis_table + for name, dtype in ibis_table.schema().items(): + if isinstance(dtype, Decimal): + # Round decimal columns to a specified scale + result_table = self._round_decimal_columns( + result_table=result_table, col_name=name, **kwargs + ) + elif isinstance(dtype, UUID): + # Convert UUID to string for compatibility + result_table = self._cast_uuid_columns( + result_table=result_table, col_name=name + ) + + return result_table + + def _cast_uuid_columns(self, result_table: Table, col_name: str) -> Table: + col = result_table[col_name] + # Convert UUID to string for compatibility + casted_col = col.cast("string") + return result_table.mutate(**{col_name: casted_col}) + + def _round_decimal_columns( + self, result_table: Table, col_name, scale: int = 9 + ) -> Table: + col = result_table[col_name] + # Maximum precision for pyarrow decimal is 38 + decimal_type = Decimal(precision=38, scale=scale) + rounded_col = col.cast(decimal_type).round(scale) + return result_table.mutate(**{col_name: rounded_col}) + @tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT) def dry_run(self, sql: str) -> None: self.connection.sql(sql) @@ -199,17 +233,34 @@ def __init__(self, connection_info: ConnectionInfo): @tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT) def query(self, sql: str, limit: int | None = None) -> pa.Table: - try: - return super().query(sql, limit) - except Exception as e: - # To descirbe the query result, ibis will wrap the query with a subquery. MSSQL doesn't - # allow order by without limit in a subquery, so we need to handle this error and provide a more user-friendly error message. - # error code 1033: https://learn.microsoft.com/zh-tw/sql/relational-databases/errors-events/database-engine-events-and-errors-1000-to-1999?view=sql-server-ver15 - if "(1033)" in e.args[1]: - raise GenericUserError( - "The query with order-by requires a specific limit to be set in MSSQL." - ) - raise + ibis_table = self.connection.sql(sql) + if limit is not None: + ibis_table = ibis_table.limit(limit) + return self._round_decimal_columns(ibis_table) + + def _round_decimal_columns(self, ibis_table: Table, scale: int = 9) -> pa.Table: + def round_decimal(val): + if val is None: + return None + d = PyDecimal(str(val)) + quant = PyDecimal("1." + "0" * scale) + return d.quantize(quant) + + decimal_columns = [] + for name, dtype in ibis_table.schema().items(): + if isinstance(dtype, Decimal): + decimal_columns.append(name) + + # If no decimal columns, return original table unchanged + if not decimal_columns: + return ibis_table.to_pyarrow() + + pandas_df = ibis_table.to_pandas() + for col_name in decimal_columns: + pandas_df[col_name] = pandas_df[col_name].apply(round_decimal) + + arrow_table = pa.Table.from_pandas(pandas_df) + return arrow_table def dry_run(self, sql: str) -> None: try: @@ -245,9 +296,25 @@ def query(self, sql: str, limit: int | None = None) -> pa.Table: ibis_table = self.connection.sql(sql, schema=schema) if limit is not None: ibis_table = ibis_table.limit(limit) - ibis_table = round_decimal_columns(ibis_table) + ibis_table = self._handle_pyarrow_unsupported_type(ibis_table) return ibis_table.to_pyarrow() + def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table: + result_table = ibis_table + for name, dtype in ibis_table.schema().items(): + if isinstance(dtype, Decimal): + # Round decimal columns to a specified scale + result_table = self._round_decimal_columns( + result_table=result_table, col_name=name, **kwargs + ) + elif isinstance(dtype, UUID): + # Convert UUID to string for compatibility + result_table = self._cast_uuid_columns( + result_table=result_table, col_name=name + ) + + return result_table + @tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT) def dry_run(self, sql: str) -> Any: # Canner enterprise does not support dry-run, so we have to query with limit zero diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index b4972f753..c48328dd7 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -83,9 +83,13 @@ def get_dto_type(self): raise NotImplementedError(f"Unsupported data source: {self}") def get_connection_info( - self, data: dict[str, Any] | ConnectionInfo, headers: dict[str, str] + self, + data: dict[str, Any] | ConnectionInfo, + headers: dict[str, str] | None = None, ) -> ConnectionInfo: """Build a ConnectionInfo object from the provided data and add requried configuration from headers.""" + + headers = headers or {} if isinstance(data, ConnectionInfo): info = data else: diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index ca4ca801b..fa4f225de 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -10,8 +10,6 @@ import pyarrow as pa import wren_core from fastapi import Header -from ibis.expr.datatypes import Decimal -from ibis.expr.types import Table from loguru import logger from opentelemetry import trace from opentelemetry.baggage.propagation import W3CBaggagePropagator @@ -195,20 +193,6 @@ def pd_to_arrow_schema(df: pd.DataFrame) -> pa.Schema: return pa.schema(fields) -def round_decimal_columns(ibis_table: Table, scale: int = 9) -> Table: - fields = [] - for name, dtype in ibis_table.schema().items(): - col = ibis_table[name] - if isinstance(dtype, Decimal): - # maxinum precision for pyarrow decimal is 38 - decimal_type = Decimal(precision=38, scale=scale) - col = col.cast(decimal_type).round(scale) - fields.append(col.name(name)) - else: - fields.append(col) - return ibis_table.select(*fields) - - def update_response_headers(response, required_headers: dict): if X_CACHE_HIT in required_headers: response.headers[X_CACHE_HIT] = required_headers[X_CACHE_HIT] diff --git a/ibis-server/tests/routers/v2/connector/test_mssql.py b/ibis-server/tests/routers/v2/connector/test_mssql.py index d9c2c2291..da17b241c 100644 --- a/ibis-server/tests/routers/v2/connector/test_mssql.py +++ b/ibis-server/tests/routers/v2/connector/test_mssql.py @@ -121,6 +121,13 @@ def mssql(request) -> SqlServerContainer: ) ) + conn.execute(text("CREATE TABLE uuid_test (order_uuid uniqueidentifier)")) + conn.execute( + text( + "INSERT INTO uuid_test (order_uuid) VALUES (cast('123e4567-e89b-12d3-a456-426614174000' as uniqueidentifier))" + ) + ) + request.addfinalizer(mssql.stop) return mssql @@ -491,21 +498,75 @@ async def test_order_by_nulls_last(client, manifest_str, mssql: SqlServerContain assert result["data"][2][0] == "three" -async def test_order_by_require_limit(client, manifest_str, mssql: SqlServerContainer): +async def test_order_by_without_limit(client, manifest_str, mssql: SqlServerContainer): connection_info = _to_connection_info(mssql) response = await client.post( url=f"{base_url}/query", json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": 'SELECT letter FROM "null_test" ORDER BY id NULLS LAST', + "sql": 'SELECT letter FROM "null_test" ORDER BY id', }, ) - assert response.status_code == 422 - assert ( - "The query with order-by requires a specific limit to be set in MSSQL." - in response.text + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 3 + assert result["data"][0][0] == "one" + assert result["data"][1][0] == "two" + assert result["data"][2][0] == "three" + + +# we dont give the expression a alias on purpose +async def test_decimal_precision(client, manifest_str, mssql: SqlServerContainer): + connection_info = _to_connection_info(mssql) + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8))", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "0.333333" + + +async def test_uuid_type(client, mssql: SqlServerContainer): + connection_info = _to_connection_info(mssql) + manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "uuid_test", + "tableReference": { + "schema": "dbo", + "table": "uuid_test", + }, + "columns": [ + {"name": "order_uuid", "type": "uuid"}, + ], + }, + ], + } + manifest_str = base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "select order_uuid from uuid_test", + }, ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "123E4567-E89B-12D3-A456-426614174000" + assert result["dtypes"] == { + "order_uuid": "string", + } def _to_connection_info(mssql: SqlServerContainer): diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index b8269f634..ca140bce3 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -1046,6 +1046,22 @@ async def test_connection_timeout(client, manifest_str, postgres: PostgresContai ) +async def test_uuid_type(client, manifest_str, postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "select '123e4567-e89b-12d3-a456-426614174000'::uuid as order_uuid", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "123e4567-e89b-12d3-a456-426614174000" + + async def test_order_by_nulls_last(client, manifest_str, postgres: PostgresContainer): connection_info = _to_connection_info(postgres) response = await client.post( diff --git a/ibis-server/wren/__main__.py b/ibis-server/wren/__main__.py index 0b3bc0939..cabf26f8b 100644 --- a/ibis-server/wren/__main__.py +++ b/ibis-server/wren/__main__.py @@ -34,7 +34,7 @@ def main(): # Otherwise, we can directly use the connection_info as is. if "type" in connection_info: connection_info = data_source.get_connection_info( - connection_info["properties"] + connection_info["properties"], ) else: connection_info = data_source.get_connection_info(connection_info)