Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def query(self, sql: str, limit: int) -> pa.Table:
return super().query(sql, limit)
except ValueError as e:
# Import here to avoid override the custom datatypes
import ibis.backends.bigquery
import ibis.backends.bigquery # noqa: PLC0415

# Try to match the error message from the google cloud bigquery library matching Arrow type error.
# If the error message matches, requires getting the schema from the result and generating an empty pandas dataframe with the mapped schema
Expand Down Expand Up @@ -190,7 +190,7 @@ def query(self, sql: str, limit: int) -> pa.Table:

class DuckDBConnector:
def __init__(self, connection_info: ConnectionInfo):
import duckdb
import duckdb # noqa: PLC0415

self.connection = duckdb.connect()
if isinstance(connection_info, S3FileConnectionInfo):
Expand Down Expand Up @@ -221,7 +221,7 @@ def dry_run(self, sql: str) -> None:

class RedshiftConnector:
def __init__(self, connection_info: RedshiftConnectionUnion):
import redshift_connector
import redshift_connector # noqa: PLC0415

if isinstance(connection_info, RedshiftIAMConnectionInfo):
self.connection = redshift_connector.connect(
Expand Down
3 changes: 1 addition & 2 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ssl
from enum import Enum, StrEnum, auto
from json import loads
from typing import Optional

import ibis
from google.oauth2 import service_account
Expand Down Expand Up @@ -241,7 +240,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
)

@staticmethod
def _create_ssl_context(info: ConnectionInfo) -> Optional[ssl.SSLContext]:
def _create_ssl_context(info: ConnectionInfo) -> ssl.SSLContext | None:
ssl_mode = (
info.ssl_mode.get_secret_value()
if hasattr(info, "ssl_mode") and info.ssl_mode
Expand Down
4 changes: 2 additions & 2 deletions ibis-server/app/query_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
import time
from typing import Any, Optional
from typing import Any

import ibis
import opendal
Expand All @@ -17,7 +17,7 @@ def __init__(self, root: str = "/tmp/wren-engine/"):
self.root = root

@tracer.start_as_current_span("get_cache", kind=trace.SpanKind.INTERNAL)
def get(self, data_source: str, sql: str, info) -> Optional[Any]:
def get(self, data_source: str, sql: str, info) -> Any | None:
cache_key = self._generate_cache_key(data_source, sql, info)
cache_file_name = self._get_cache_file_name(cache_key)
op = self._get_dal_operator()
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def query(
# case 5~8 Other cases (cache is not enabled)
case (False, _, _):
pass
response = ORJSONResponse(to_json(result, headers))
response = ORJSONResponse(to_json(result, headers, data_source=data_source))
update_response_headers(response, cache_headers)

if is_fallback:
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def query(
case (False, _, _):
pass

response = ORJSONResponse(to_json(result, headers))
response = ORJSONResponse(to_json(result, headers, data_source=data_source))
update_response_headers(response, cache_headers)
return response
except Exception as e:
Expand Down
116 changes: 73 additions & 43 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64

import duckdb
import datafusion
import orjson
import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -40,7 +40,8 @@ def base64_to_dict(base64_str: str) -> dict:


@tracer.start_as_current_span("to_json", kind=trace.SpanKind.INTERNAL)
def to_json(df: pa.Table, headers: dict) -> dict:
def to_json(df: pa.Table, headers: dict, data_source: DataSource = None) -> dict:
df = _with_session_timezone(df, headers, data_source)
dtypes = {field.name: str(field.type) for field in df.schema}
if df.num_rows == 0:
return {
Expand All @@ -49,46 +50,76 @@ def to_json(df: pa.Table, headers: dict) -> dict:
"dtypes": dtypes,
}

ctx = get_datafusion_context(headers)
ctx.register_record_batches(name="arrow_table", partitions=[df.to_batches()])

formatted_sql = (
"SELECT " + ", ".join([_formater(field) for field in df.schema]) + " FROM df"
"SELECT "
+ ", ".join([_formater(field) for field in df.schema])
+ " FROM arrow_table"
)
logger.debug(f"formmated_sql: {formatted_sql}")
conn = get_duckdb_conn(headers)
formatted_df = conn.execute(formatted_sql).fetch_df()
formatted_df = ctx.sql(formatted_sql).to_pandas()

result = formatted_df.to_dict(orient="split")
result["dtypes"] = dtypes
result.pop("index", None) # Remove index field from the DuckDB result
return result


def get_duckdb_conn(headers: dict) -> duckdb.DuckDBPyConnection:
"""Get a DuckDB connection with the provided headers."""
conn = duckdb.connect()
if X_WREN_TIMEZONE in headers:
timezone = headers[X_WREN_TIMEZONE]
if timezone.startwith("+") or timezone.startswith("-"):
# If the timezone is an offset, convert it to a named timezone
timezone = get_timezone_from_offset(timezone)
conn.execute("SET TimeZone = ?", [timezone])
else:
# Default to UTC if no timezone is provided
conn.execute("SET TimeZone = 'UTC'")
def _with_session_timezone(
df: pa.Table, headers: dict, data_source: DataSource
) -> pa.Table:
fields = []

return conn
for field in df.schema:
if pa.types.is_timestamp(field.type):
if field.type.tz is not None and X_WREN_TIMEZONE in headers:
# change the timezone to the session timezone
fields.append(
pa.field(
field.name,
pa.timestamp(field.type.unit, tz=headers[X_WREN_TIMEZONE]),
nullable=True,
)
)
continue
if data_source == DataSource.mysql:
timezone = headers.get(X_WREN_TIMEZONE, "UTC")
# TODO: ibis mysql loss the timezone information
# we cast timestamp to timestamp with session timezone for mysql
fields.append(
pa.field(
field.name,
pa.timestamp(field.type.unit, tz=timezone),
nullable=True,
)
)
continue

# TODO: the field's nullable should be True if the value contains null but
# the arrow table produced by the ibis clickhouse connector always set nullable to False
# so we set nullable to True here to avoid the casting error
fields.append(
pa.field(
field.name,
field.type,
nullable=True,
)
)
return df.cast(pa.schema(fields))


def get_timezone_from_offset(offset: str) -> str:
if offset.startswith("+"):
offset = offset[1:] # Remove the leading '+' sign
def get_datafusion_context(headers: dict) -> datafusion.SessionContext:
config = datafusion.SessionConfig()
if X_WREN_TIMEZONE in headers:
config.set("datafusion.execution.time_zone", headers[X_WREN_TIMEZONE])
else:
# Default to UTC if no timezone is provided
config.set("datafusion.execution.time_zone", "UTC")

first = duckdb.execute(
"SELECT name, utc_offset FROM pg_timezone_names() WHERE utc_offset = ?",
[offset],
).fetchone()
if first is None:
raise ValueError(f"Invalid timezone offset: {offset}")
return first[0] # Return the timezone name
ctx = datafusion.SessionContext(config=config)
return ctx


def build_context(headers: Header) -> Context:
Expand Down Expand Up @@ -166,30 +197,29 @@ def update_response_headers(response, required_headers: dict):
response.headers[X_CACHE_OVERRIDE_AT] = required_headers[X_CACHE_OVERRIDE_AT]


def _quote_identifier(identifier: str) -> str:
identifier = identifier.replace('"', '""') # Escape double quotes
return f'"{identifier}"' if identifier else identifier


def _formater(field: pa.Field) -> str:
column_name = _quote_identifier(field.name)
if pa.types.is_decimal(field.type) or pa.types.is_floating(field.type):
return f"""
case when {column_name} = 0 then '0'
when length(CAST({column_name} AS VARCHAR)) > 15 then format('{{:.9g}}', {column_name})
else RTRIM(RTRIM(format('{{:.8f}}', {column_name}), '0'), '.')
end as {column_name}"""
if pa.types.is_decimal(field.type):
# TODO: maybe implementing a to_char udf to format decimals would be better
# Currently, if the number is less than 1, it is shown with exponential notation when the length of the fractional digits is greater than 7
# e.g. 0.0000123 is shown without exponential notation but 0.00000123 is shown with exponential notation as 1.23e-6
return f"case when {column_name} = 0 then '0' else cast({column_name} as double) end as {column_name}"
elif pa.types.is_date(field.type):
return f"strftime({column_name}, '%Y-%m-%d') as {column_name}"
return f"to_char({column_name}, '%Y-%m-%d') as {column_name}"
elif pa.types.is_timestamp(field.type):
if field.type.tz is None:
return f"strftime({column_name}, '%Y-%m-%d %H:%M:%S.%f') as {column_name}"
return f"to_char({column_name}, '%Y-%m-%d %H:%M:%S%.6f') as {column_name}"
else:
return (
f"strftime({column_name}, '%Y-%m-%d %H:%M:%S.%f %Z') as {column_name}"
f"to_char({column_name}, '%Y-%m-%d %H:%M:%S%.6f %Z') as {column_name}"
)
elif pa.types.is_binary(field.type):
return f"to_hex({column_name}) as {column_name}"
return f"encode({column_name}, 'hex') as {column_name}"
elif pa.types.is_interval(field.type):
return f"cast({column_name} as varchar) as {column_name}"
return column_name


def _quote_identifier(identifier: str) -> str:
identifier = identifier.replace('"', '""') # Escape double quotes
return f'"{identifier}"' if identifier else identifier
24 changes: 22 additions & 2 deletions ibis-server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ gunicorn = "^23.0.0"
uvicorn-worker = "^0.3.0"
jinja2 = ">=3.1.6"
redshift_connector = "2.1.7"
datafusion = "^47.0.0"

[tool.poetry.group.dev.dependencies]
pytest = "8.3.5"
Expand Down
10 changes: 5 additions & 5 deletions ibis-server/tests/routers/v2/connector/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ async def test_query(client, manifest_str):
1,
370,
"O",
"172799.49",
172799.49,
"1996-01-02",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"2024-01-01 23:59:59.000000 +00:00",
None,
"616263",
]
Expand Down Expand Up @@ -317,7 +317,7 @@ async def test_interval(client, manifest_str):
)
assert response.status_code == 200
result = response.json()
assert result["data"][0] == ["9 years 4 months 100 days 01:00:00"]
assert result["data"][0] == ["112 mons 100 days 1 hours"]
assert result["dtypes"] == {"col": "month_day_nano_interval"}


Expand All @@ -332,7 +332,7 @@ async def test_avg_interval(client, manifest_str):
)
assert response.status_code == 200
result = response.json()
assert result["data"][0] == ["10484 days 08:54:14.4"]
assert result["data"][0] == ["10484 days 8 hours 54 mins 14.400000000 secs"]
assert result["dtypes"] == {"col": "month_day_nano_interval"}


Expand Down Expand Up @@ -362,7 +362,7 @@ async def test_custom_datatypes_no_overrides(client, manifest_str):
)
assert response.status_code == 200
result = response.json()
assert result["data"][0] == ["9 years 4 months 100 days 01:00:00"]
assert result["data"][0] == ["112 mons 100 days 1 hours"]
assert result["dtypes"] == {"col": "month_day_nano_interval"}


Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v2/connector/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def test_query(client, manifest_str, clickhouse: ClickHouseContainer):
"1996-01-02",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"2024-01-01 23:59:59.000000 +00:00",
None,
"abc", # Clickhouse does not support bytea, so it is returned as string
]
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v2/connector/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def test_query(client, manifest_str, mssql: SqlServerContainer):
"1996-01-02",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"2024-01-01 23:59:59.000000 +00:00",
None,
"616263",
]
Expand Down
10 changes: 5 additions & 5 deletions ibis-server/tests/routers/v2/connector/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ async def test_query(client, manifest_str, mysql: MySqlContainer):
"172799.49",
"1996-01-02",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 +00:00",
"2024-01-01 23:59:59.000000 +00:00",
None,
"616263",
]
Expand All @@ -174,9 +174,9 @@ async def test_query(client, manifest_str, mysql: MySqlContainer):
"totalprice": "string",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[us]",
"timestamptz": "timestamp[us]",
"test_null_time": "timestamp[us]",
"timestamp": "timestamp[us, tz=UTC]",
"timestamptz": "timestamp[us, tz=UTC]",
"test_null_time": "timestamp[us, tz=UTC]",
"bytea_column": "binary",
}

Expand Down
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v2/connector/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def test_query(client, manifest_str, oracle: OracleDbContainer):
"1996-01-02",
"1_370",
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"2024-01-01 23:59:59.000000 +00:00",
None,
"616263",
]
Expand Down
Loading
Loading