diff --git a/ibis-server/.pre-commit-config.yaml b/ibis-server/.pre-commit-config.yaml index 8a365834e..c47ba0032 100644 --- a/ibis-server/.pre-commit-config.yaml +++ b/ibis-server/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.8 + rev: v0.12.0 hooks: # Run the linter. - id: ruff diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 08a6604cc..39864acf8 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -8,7 +8,6 @@ import ibis import ibis.expr.datatypes as dt import ibis.expr.schema as sch -import ibis.formats import pandas as pd import pyarrow as pa import sqlglot.expressions as sge @@ -32,6 +31,7 @@ ) from app.model.data_source import DataSource from app.model.utils import init_duckdb_gcs, init_duckdb_minio, init_duckdb_s3 +from app.util import round_decimal_columns # Override datatypes of ibis importlib.import_module("app.custom_ibis.backends.sql.datatypes") @@ -77,7 +77,9 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): @tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT) def query(self, sql: str, limit: int) -> pa.Table: - return self.connection.sql(sql).limit(limit).to_pyarrow() + ibis_table = self.connection.sql(sql).limit(limit) + ibis_table = round_decimal_columns(ibis_table) + return ibis_table.to_pyarrow() @tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT) def dry_run(self, sql: str) -> None: @@ -116,10 +118,12 @@ def __init__(self, connection_info: ConnectionInfo): self.connection = DataSource.canner.get_connection(connection_info) @tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT) - def query(self, sql: str, limit: int) -> pd.DataFrame: + def query(self, sql: str, limit: int) -> pa.Table: # Canner enterprise does not support `CREATE TEMPORARY VIEW` for getting schema schema = 
self._get_schema(sql) - return self.connection.sql(sql, schema=schema).limit(limit).to_pyarrow() + ibis_table = self.connection.sql(sql, schema=schema).limit(limit) + ibis_table = round_decimal_columns(ibis_table) + return ibis_table.to_pyarrow() @tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT) def dry_run(self, sql: str) -> Any: diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index f1a6ab201..7538aa0c4 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -6,6 +6,8 @@ import pyarrow as pa import wren_core from fastapi import Header +from ibis.expr.datatypes import Decimal +from ibis.expr.types import Table from loguru import logger from opentelemetry import trace from opentelemetry.baggage.propagation import W3CBaggagePropagator @@ -186,6 +188,20 @@ def pd_to_arrow_schema(df: pd.DataFrame) -> pa.Schema: return pa.schema(fields) +def round_decimal_columns(ibis_table: Table, scale: int = 9) -> Table: + fields = [] + for name, dtype in ibis_table.schema().items(): + col = ibis_table[name] + if isinstance(dtype, Decimal): + # maximum precision for pyarrow decimal is 38 + decimal_type = Decimal(precision=38, scale=scale) + col = col.cast(decimal_type).round(scale) + fields.append(col.name(name)) + else: + fields.append(col) + return ibis_table.select(*fields) + + def update_response_headers(response, required_headers: dict): if X_CACHE_HIT in required_headers: response.headers[X_CACHE_HIT] = required_headers[X_CACHE_HIT] diff --git a/ibis-server/tests/routers/v2/connector/test_clickhouse.py b/ibis-server/tests/routers/v2/connector/test_clickhouse.py index d6cb83edb..fe5f5266d 100644 --- a/ibis-server/tests/routers/v2/connector/test_clickhouse.py +++ b/ibis-server/tests/routers/v2/connector/test_clickhouse.py @@ -191,7 +191,7 @@ async def test_query(client, manifest_str, clickhouse: ClickHouseContainer): "orderkey": "int32", "custkey": "int32", "orderstatus": "string", - "totalprice": "decimal128(15, 2)", + 
"totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[ns]", @@ -310,7 +310,7 @@ async def test_query_to_many_relationship( assert len(result["data"]) == 1 assert result["data"][0] == ["2860895.79"] assert result["dtypes"] == { - "totalprice": "decimal128(38, 2)", + "totalprice": "decimal128(38, 9)", } diff --git a/ibis-server/tests/routers/v2/connector/test_oracle.py b/ibis-server/tests/routers/v2/connector/test_oracle.py index 5fecc6858..d1e5e6a46 100644 --- a/ibis-server/tests/routers/v2/connector/test_oracle.py +++ b/ibis-server/tests/routers/v2/connector/test_oracle.py @@ -202,7 +202,7 @@ async def test_query(client, manifest_str, oracle: OracleDbContainer): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(38, 2)", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[ns]", diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 30ec6b31a..9501b972e 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -1010,6 +1010,22 @@ async def test_postgis_geometry(client, manifest_str, postgis: PostgresContainer assert result["data"][0] == [74.66265347816136] +async def test_decimal_precision(client, manifest_str, postgres: PostgresContainer): + connection_info = _to_connection_info(postgres) + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "0.333333333" + + def _to_connection_info(pg: PostgresContainer): return { "host": 
pg.get_container_host_ip(), diff --git a/ibis-server/tests/routers/v2/connector/test_s3_file.py b/ibis-server/tests/routers/v2/connector/test_s3_file.py index 4e2603ae6..7f2b2be98 100644 --- a/ibis-server/tests/routers/v2/connector/test_s3_file.py +++ b/ibis-server/tests/routers/v2/connector/test_s3_file.py @@ -163,7 +163,7 @@ async def test_query_calculated_field(client, manifest_str, connection_info): ] assert result["dtypes"] == { "custkey": "int32", - "sum_totalprice": "decimal128(38, 2)", + "sum_totalprice": "decimal128(38, 9)", } diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py index e353bea9f..42622bb19 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py @@ -321,3 +321,18 @@ async def test_timestamp_func(client, manifest_str, connection_info): assert result["dtypes"] == { "compare": "bool", } + + +async def test_decimal_precision(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "0.333333333" diff --git a/ibis-server/tests/routers/v3/connector/oracle/test_query.py b/ibis-server/tests/routers/v3/connector/oracle/test_query.py index f79be229f..f7982ba60 100644 --- a/ibis-server/tests/routers/v3/connector/oracle/test_query.py +++ b/ibis-server/tests/routers/v3/connector/oracle/test_query.py @@ -104,7 +104,7 @@ async def test_query(client, manifest_str, connection_info): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(38, 2)", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", 
"order_cust_key": "string", "timestamp": "timestamp[ns]", @@ -178,7 +178,7 @@ async def test_query_number_scale(client, connection_info): { "name": "id_p_s", "expression": '"ID_P_S"', - "type": "decimal128(10, 2)", + "type": "decimal128(38, 9)", }, ], "primaryKey": "id", @@ -207,5 +207,5 @@ async def test_query_number_scale(client, connection_info): assert result["dtypes"] == { "id": "int64", "id_p": "int64", - "id_p_s": "decimal128(10, 2)", + "id_p_s": "decimal128(38, 9)", } diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index 144199b57..1356d66f1 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -655,8 +655,7 @@ async def test_format_floating(client, manifest_str, connection_info): CAST(1.23e4 AS DECIMAL(10,5)) AS case_cast_decimal, CAST(1.234e+14 AS DECIMAL(20,0)) AS show_float, CAST(1.234e+15 AS DECIMAL(20,0)) AS show_exponent, - CAST(1.123456789 AS DECIMAL(20,9)) AS round_to_9_decimal_places, - CAST(0.123456789123456789 AS DECIMAL(20,18)) AS round_to_18_decimal_places + CAST(1.123456789 AS DECIMAL(20,9)) AS round_to_9_decimal_places """, }, headers={ @@ -695,5 +694,19 @@ async def test_format_floating(client, manifest_str, connection_info): # DataFusion does not support it, so we show the full number "1234000000000000.0", # show_exponent "1.123456789", # round_to_9_decimal_places - "0.12345678912345678", # round_to_18_decimal_places ] + + +async def test_decimal_precision(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT cast(1 as decimal(38, 8)) / cast(3 as decimal(38, 8)) as result", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "0.333333333"