Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 81 additions & 14 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time
from contextlib import closing, suppress
from decimal import Decimal as PyDecimal
from functools import cache
from json import loads
from typing import Any
Expand All @@ -19,6 +20,9 @@
from google.oauth2 import service_account
from ibis import BaseBackend
from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler
from ibis.expr.datatypes import Decimal
from ibis.expr.datatypes.core import UUID
from ibis.expr.types import Table
from loguru import logger
from opentelemetry import trace

Expand All @@ -35,7 +39,6 @@
)
from app.model.data_source import DataSource
from app.model.utils import init_duckdb_gcs, init_duckdb_minio, init_duckdb_s3
from app.util import round_decimal_columns

# Override datatypes of ibis
importlib.import_module("app.custom_ibis.backends.sql.datatypes")
Expand Down Expand Up @@ -115,9 +118,40 @@ def query(self, sql: str, limit: int | None = None) -> pa.Table:
ibis_table = self.connection.sql(sql)
if limit is not None:
ibis_table = ibis_table.limit(limit)
ibis_table = round_decimal_columns(ibis_table)
ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
return ibis_table.to_pyarrow()

def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
    """Adjust columns whose ibis dtype needs special handling.

    Walks the schema of ``ibis_table`` column by column: ``Decimal`` columns
    are rounded via ``_round_decimal_columns`` and ``UUID`` columns are cast
    via ``_cast_uuid_columns``. Columns of any other dtype are left alone.
    The caller in this file invokes it right before ``to_pyarrow()``.

    Args:
        ibis_table: Table expression whose schema is inspected.
        **kwargs: Forwarded unchanged to ``_round_decimal_columns``.

    Returns:
        The table expression with every affected column replaced.
    """
    converted = ibis_table
    for column_name, column_type in ibis_table.schema().items():
        if isinstance(column_type, Decimal):
            # Decimal columns are rounded to a fixed scale.
            converted = self._round_decimal_columns(
                result_table=converted, col_name=column_name, **kwargs
            )
            continue
        if isinstance(column_type, UUID):
            # UUID columns are turned into strings for compatibility.
            converted = self._cast_uuid_columns(
                result_table=converted, col_name=column_name
            )
    return converted

def _cast_uuid_columns(self, result_table: Table, col_name: str) -> Table:
    """Replace the column ``col_name`` with its string-cast version.

    Args:
        result_table: Table expression containing the column.
        col_name: Name of the column to cast.

    Returns:
        ``result_table`` with ``col_name`` overwritten by the cast column.
    """
    # mutate() with an existing name overwrites that column in place.
    as_string = result_table[col_name].cast("string")
    replacement = {col_name: as_string}
    return result_table.mutate(**replacement)

def _round_decimal_columns(
    self, result_table: Table, col_name, scale: int = 9
) -> Table:
    """Cast one decimal column to ``Decimal(38, scale)`` and round it.

    Args:
        result_table: Table expression containing the column.
        col_name: Name of the decimal column to rewrite.
        scale: Number of fractional digits to keep.

    Returns:
        ``result_table`` with ``col_name`` overwritten by the rounded column.
    """
    source_col = result_table[col_name]
    # 38 is the maximum precision of a pyarrow decimal.
    target_type = Decimal(precision=38, scale=scale)
    replacement = {col_name: source_col.cast(target_type).round(scale)}
    return result_table.mutate(**replacement)

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> None:
self.connection.sql(sql)
Expand Down Expand Up @@ -199,17 +233,34 @@ def __init__(self, connection_info: ConnectionInfo):

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int | None = None) -> pa.Table:
try:
return super().query(sql, limit)
except Exception as e:
# To describe the query result, ibis will wrap the query with a subquery. MSSQL doesn't
# allow order by without limit in a subquery, so we need to handle this error and provide a more user-friendly error message.
# error code 1033: https://learn.microsoft.com/zh-tw/sql/relational-databases/errors-events/database-engine-events-and-errors-1000-to-1999?view=sql-server-ver15
if "(1033)" in e.args[1]:
raise GenericUserError(
"The query with order-by requires a specific limit to be set in MSSQL."
)
raise
ibis_table = self.connection.sql(sql)
if limit is not None:
ibis_table = ibis_table.limit(limit)
return self._round_decimal_columns(ibis_table)

def _round_decimal_columns(self, ibis_table: Table, scale: int = 9) -> pa.Table:
    """Execute ``ibis_table`` and return a pyarrow table with rounded decimals.

    Columns whose ibis dtype is ``Decimal`` are quantized in Python to
    ``scale`` fractional digits after the result has been fetched into
    pandas. All other columns are passed through untouched.

    NOTE(review): another ``_round_decimal_columns`` in this file takes
    ``(result_table, col_name, scale)`` and returns an ibis ``Table``. If this
    class inherits from the class defining it, the two are not
    call-compatible -- confirm this override is only reached from ``query``.

    Args:
        ibis_table: The ibis table expression to execute.
        scale: Number of digits kept after the decimal point.

    Returns:
        The query result as a ``pa.Table``. When the schema has no decimal
        column it comes straight from ``to_pyarrow()``; otherwise it is built
        from the rounded pandas DataFrame.

    Raises:
        decimal.InvalidOperation: If a rounded value needs more than 38
            significant digits.
    """
    # Local import: the module only imports ``Decimal`` (as ``PyDecimal``).
    from decimal import Context

    decimal_columns = [
        name
        for name, dtype in ibis_table.schema().items()
        if isinstance(dtype, Decimal)
    ]

    # If there are no decimal columns, skip the pandas round trip entirely.
    if not decimal_columns:
        return ibis_table.to_pyarrow()

    # Loop-invariant: build the quantum once instead of once per cell.
    quant = PyDecimal("1." + "0" * scale)
    # quantize() signals InvalidOperation when the result has more digits
    # than the context precision. The default context allows only 28 digits,
    # which is narrower than the 38-digit decimals this file works with, so
    # use an explicit 38-digit context.
    wide_context = Context(prec=38)

    def round_decimal(val):
        if val is None:
            return None
        return PyDecimal(str(val)).quantize(quant, context=wide_context)

    pandas_df = ibis_table.to_pandas()
    for col_name in decimal_columns:
        pandas_df[col_name] = pandas_df[col_name].apply(round_decimal)

    return pa.Table.from_pandas(pandas_df)

def dry_run(self, sql: str) -> None:
try:
Expand Down Expand Up @@ -245,9 +296,25 @@ def query(self, sql: str, limit: int | None = None) -> pa.Table:
ibis_table = self.connection.sql(sql, schema=schema)
if limit is not None:
ibis_table = ibis_table.limit(limit)
ibis_table = round_decimal_columns(ibis_table)
ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
return ibis_table.to_pyarrow()

def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
    """Return ``ibis_table`` with decimal and UUID columns rewritten.

    Each schema entry is inspected in order. ``Decimal`` columns go through
    ``_round_decimal_columns`` (which also receives ``**kwargs``) and ``UUID``
    columns go through ``_cast_uuid_columns``. Every other column is kept
    unchanged.

    Args:
        ibis_table: Table expression whose schema drives the rewrite.
        **kwargs: Extra keyword arguments passed to ``_round_decimal_columns``.

    Returns:
        The rewritten table expression.
    """
    adjusted = ibis_table
    schema_entries = ibis_table.schema().items()
    for field_name, field_type in schema_entries:
        is_decimal = isinstance(field_type, Decimal)
        if not is_decimal and not isinstance(field_type, UUID):
            # Nothing to do for dtypes other than Decimal and UUID.
            continue
        if is_decimal:
            # Round decimal columns to a specified scale.
            adjusted = self._round_decimal_columns(
                result_table=adjusted, col_name=field_name, **kwargs
            )
        else:
            # Convert UUID to string for compatibility.
            adjusted = self._cast_uuid_columns(
                result_table=adjusted, col_name=field_name
            )
    return adjusted

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> Any:
# Canner enterprise does not support dry-run, so we have to query with limit zero
Expand Down
6 changes: 5 additions & 1 deletion ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ def get_dto_type(self):
raise NotImplementedError(f"Unsupported data source: {self}")

def get_connection_info(
self, data: dict[str, Any] | ConnectionInfo, headers: dict[str, str]
self,
data: dict[str, Any] | ConnectionInfo,
headers: dict[str, str] | None = None,
) -> ConnectionInfo:
"""Build a ConnectionInfo object from the provided data and add requried configuration from headers."""

headers = headers or {}
if isinstance(data, ConnectionInfo):
info = data
else:
Expand Down
16 changes: 0 additions & 16 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import pyarrow as pa
import wren_core
from fastapi import Header
from ibis.expr.datatypes import Decimal
from ibis.expr.types import Table
from loguru import logger
from opentelemetry import trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
Expand Down Expand Up @@ -195,20 +193,6 @@ def pd_to_arrow_schema(df: pd.DataFrame) -> pa.Schema:
return pa.schema(fields)


def round_decimal_columns(ibis_table: Table, scale: int = 9) -> Table:
    """Select every column of ``ibis_table``, rounding the decimal ones.

    Args:
        ibis_table: Table expression to project.
        scale: Number of fractional digits kept on decimal columns.

    Returns:
        A table expression with the same column names, where each ``Decimal``
        column has been cast to ``Decimal(38, scale)`` and rounded.
    """

    def _project(name, dtype):
        column = ibis_table[name]
        if not isinstance(dtype, Decimal):
            return column
        # Maximum precision for a pyarrow decimal is 38.
        target_type = Decimal(precision=38, scale=scale)
        # Re-apply the original name after the cast/round expression.
        return column.cast(target_type).round(scale).name(name)

    projected = [_project(name, dtype) for name, dtype in ibis_table.schema().items()]
    return ibis_table.select(*projected)


def update_response_headers(response, required_headers: dict):
if X_CACHE_HIT in required_headers:
response.headers[X_CACHE_HIT] = required_headers[X_CACHE_HIT]
Expand Down
73 changes: 67 additions & 6 deletions ibis-server/tests/routers/v2/connector/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ def mssql(request) -> SqlServerContainer:
)
)

conn.execute(text("CREATE TABLE uuid_test (order_uuid uniqueidentifier)"))
conn.execute(
text(
"INSERT INTO uuid_test (order_uuid) VALUES (cast('123e4567-e89b-12d3-a456-426614174000' as uniqueidentifier))"
)
)

request.addfinalizer(mssql.stop)
return mssql

Expand Down Expand Up @@ -491,21 +498,75 @@ async def test_order_by_nulls_last(client, manifest_str, mssql: SqlServerContain
assert result["data"][2][0] == "three"


async def test_order_by_require_limit(client, manifest_str, mssql: SqlServerContainer):
async def test_order_by_without_limit(client, manifest_str, mssql: SqlServerContainer):
connection_info = _to_connection_info(mssql)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT letter FROM "null_test" ORDER BY id NULLS LAST',
"sql": 'SELECT letter FROM "null_test" ORDER BY id',
},
)
assert response.status_code == 422
assert (
"The query with order-by requires a specific limit to be set in MSSQL."
in response.text
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 3
assert result["data"][0][0] == "one"
assert result["data"][1][0] == "two"
assert result["data"][2][0] == "three"


# We don't give the expression an alias on purpose.
async def test_decimal_precision(client, manifest_str, mssql: SqlServerContainer):
    """Divide two ``decimal(38, 8)`` values and check the returned value."""
    connection_info = _to_connection_info(mssql)
    request_body = {
        "connectionInfo": connection_info,
        "manifestStr": manifest_str,
        "sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8))",
    }
    response = await client.post(url=f"{base_url}/query", json=request_body)
    assert response.status_code == 200

    rows = response.json()["data"]
    assert len(rows) == 1
    assert rows[0][0] == "0.333333"


async def test_uuid_type(client, mssql: SqlServerContainer):
    """Query the ``uuid_test`` table and expect the UUID back as a string."""
    connection_info = _to_connection_info(mssql)

    # Minimal manifest exposing only the uuid_test table.
    uuid_model = {
        "name": "uuid_test",
        "tableReference": {
            "schema": "dbo",
            "table": "uuid_test",
        },
        "columns": [
            {"name": "order_uuid", "type": "uuid"},
        ],
    }
    manifest = {
        "catalog": "my_catalog",
        "schema": "my_schema",
        "models": [uuid_model],
    }
    encoded_manifest = base64.b64encode(orjson.dumps(manifest)).decode("utf-8")

    request_body = {
        "connectionInfo": connection_info,
        "manifestStr": encoded_manifest,
        "sql": "select order_uuid from uuid_test",
    }
    response = await client.post(url=f"{base_url}/query", json=request_body)
    assert response.status_code == 200

    body = response.json()
    rows = body["data"]
    assert len(rows) == 1
    # The expected value is the upper-cased form of the UUID the fixture inserts.
    assert rows[0][0] == "123E4567-E89B-12D3-A456-426614174000"
    assert body["dtypes"] == {
        "order_uuid": "string",
    }


def _to_connection_info(mssql: SqlServerContainer):
Expand Down
16 changes: 16 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,22 @@ async def test_connection_timeout(client, manifest_str, postgres: PostgresContai
)


async def test_uuid_type(client, manifest_str, postgres: PostgresContainer):
    """Select a literal cast to ``uuid`` and expect it back unchanged as text."""
    connection_info = _to_connection_info(postgres)
    request_body = {
        "connectionInfo": connection_info,
        "manifestStr": manifest_str,
        "sql": "select '123e4567-e89b-12d3-a456-426614174000'::uuid as order_uuid",
    }
    response = await client.post(url=f"{base_url}/query", json=request_body)
    assert response.status_code == 200

    rows = response.json()["data"]
    assert len(rows) == 1
    assert rows[0][0] == "123e4567-e89b-12d3-a456-426614174000"


async def test_order_by_nulls_last(client, manifest_str, postgres: PostgresContainer):
connection_info = _to_connection_info(postgres)
response = await client.post(
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/wren/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def main():
# Otherwise, we can directly use the connection_info as is.
if "type" in connection_info:
connection_info = data_source.get_connection_info(
connection_info["properties"]
connection_info["properties"],
)
else:
connection_info = data_source.get_connection_info(connection_info)
Expand Down