Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ibis-server/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.8
rev: v0.12.0
hooks:
# Run the linter.
- id: ruff
Expand Down
12 changes: 8 additions & 4 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
import ibis.formats
import pandas as pd
import pyarrow as pa
import sqlglot.expressions as sge
Expand All @@ -32,6 +31,7 @@
)
from app.model.data_source import DataSource
from app.model.utils import init_duckdb_gcs, init_duckdb_minio, init_duckdb_s3
from app.util import round_decimal_columns

# Override datatypes of ibis
importlib.import_module("app.custom_ibis.backends.sql.datatypes")
Expand Down Expand Up @@ -77,7 +77,9 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int) -> pa.Table:
return self.connection.sql(sql).limit(limit).to_pyarrow()
ibis_table = self.connection.sql(sql).limit(limit)
ibis_table = round_decimal_columns(ibis_table)
return ibis_table.to_pyarrow()

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> None:
Expand Down Expand Up @@ -116,10 +118,12 @@ def __init__(self, connection_info: ConnectionInfo):
self.connection = DataSource.canner.get_connection(connection_info)

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int) -> pd.DataFrame:
def query(self, sql: str, limit: int) -> pa.Table:
# Canner enterprise does not support `CREATE TEMPORARY VIEW` for getting schema
schema = self._get_schema(sql)
return self.connection.sql(sql, schema=schema).limit(limit).to_pyarrow()
ibis_table = self.connection.sql(sql, schema=schema).limit(limit)
ibis_table = round_decimal_columns(ibis_table)
return ibis_table.to_pyarrow()

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> Any:
Expand Down
16 changes: 16 additions & 0 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pyarrow as pa
import wren_core
from fastapi import Header
from ibis.expr.datatypes import Decimal
from ibis.expr.types import Table
from loguru import logger
from opentelemetry import trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
Expand Down Expand Up @@ -186,6 +188,20 @@ def pd_to_arrow_schema(df: pd.DataFrame) -> pa.Schema:
return pa.schema(fields)


def round_decimal_columns(ibis_table: Table, scale: int = 9) -> Table:
fields = []
for name, dtype in ibis_table.schema().items():
col = ibis_table[name]
if isinstance(dtype, Decimal):
# maxinum precision for pyarrow decimal is 38
decimal_type = Decimal(precision=38, scale=scale)
col = col.cast(decimal_type).round(scale)
fields.append(col.name(name))
else:
fields.append(col)
return ibis_table.select(*fields)


def update_response_headers(response, required_headers: dict):
if X_CACHE_HIT in required_headers:
response.headers[X_CACHE_HIT] = required_headers[X_CACHE_HIT]
Expand Down
4 changes: 2 additions & 2 deletions ibis-server/tests/routers/v2/connector/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def test_query(client, manifest_str, clickhouse: ClickHouseContainer):
"orderkey": "int32",
"custkey": "int32",
"orderstatus": "string",
"totalprice": "decimal128(15, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[ns]",
Expand Down Expand Up @@ -310,7 +310,7 @@ async def test_query_to_many_relationship(
assert len(result["data"]) == 1
assert result["data"][0] == ["2860895.79"]
assert result["dtypes"] == {
"totalprice": "decimal128(38, 2)",
"totalprice": "decimal128(38, 9)",
}


Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v2/connector/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def test_query(client, manifest_str, oracle: OracleDbContainer):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(38, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[ns]",
Expand Down
16 changes: 16 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,22 @@ async def test_postgis_geometry(client, manifest_str, postgis: PostgresContainer
assert result["data"][0] == [74.66265347816136]


async def test_decimal_precision(client, manifest_str, postgres: PostgresContainer):
connection_info = _to_connection_info(postgres)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "0.333333333"


def _to_connection_info(pg: PostgresContainer):
return {
"host": pg.get_container_host_ip(),
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v2/connector/test_s3_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async def test_query_calculated_field(client, manifest_str, connection_info):
]
assert result["dtypes"] == {
"custkey": "int32",
"sum_totalprice": "decimal128(38, 2)",
"sum_totalprice": "decimal128(38, 9)",
}


Expand Down
15 changes: 15 additions & 0 deletions ibis-server/tests/routers/v3/connector/bigquery/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,18 @@ async def test_timestamp_func(client, manifest_str, connection_info):
assert result["dtypes"] == {
"compare": "bool",
}


async def test_decimal_precision(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "0.333333333"
6 changes: 3 additions & 3 deletions ibis-server/tests/routers/v3/connector/oracle/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def test_query(client, manifest_str, connection_info):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(38, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[ns]",
Expand Down Expand Up @@ -178,7 +178,7 @@ async def test_query_number_scale(client, connection_info):
{
"name": "id_p_s",
"expression": '"ID_P_S"',
"type": "decimal128(10, 2)",
"type": "decimal128(38, 9)",
},
],
"primaryKey": "id",
Expand Down Expand Up @@ -207,5 +207,5 @@ async def test_query_number_scale(client, connection_info):
assert result["dtypes"] == {
"id": "int64",
"id_p": "int64",
"id_p_s": "decimal128(10, 2)",
"id_p_s": "decimal128(38, 9)",
}
19 changes: 16 additions & 3 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,7 @@ async def test_format_floating(client, manifest_str, connection_info):
CAST(1.23e4 AS DECIMAL(10,5)) AS case_cast_decimal,
CAST(1.234e+14 AS DECIMAL(20,0)) AS show_float,
CAST(1.234e+15 AS DECIMAL(20,0)) AS show_exponent,
CAST(1.123456789 AS DECIMAL(20,9)) AS round_to_9_decimal_places,
CAST(0.123456789123456789 AS DECIMAL(20,18)) AS round_to_18_decimal_places
CAST(1.123456789 AS DECIMAL(20,9)) AS round_to_9_decimal_places
""",
},
headers={
Expand Down Expand Up @@ -695,5 +694,19 @@ async def test_format_floating(client, manifest_str, connection_info):
# DataFusion does not support it, so we show the full number
"1234000000000000.0", # show_exponent
"1.123456789", # round_to_9_decimal_places
"0.12345678912345678", # round_to_18_decimal_places
]


async def test_decimal_precision(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "0.333333333"