Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ibis-server/app/custom_ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@


class BigQueryType(datatypes.BigQueryType):
default_interval_precision = "s"
# It's a workaround for the issue of the ibis BigQuery connector not supporting interval precision.
# Set `h` to avoid ibis trying to cast an arrow interval to a duration, which leads to a casting error in pyarrow.
# See: https://github.com/ibis-project/ibis/blob/main/ibis/formats/pyarrow.py#L182
default_interval_precision = "h"


datatypes.BigQueryType = BigQueryType
5 changes: 5 additions & 0 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable"
X_WREN_VARIABLE_PREFIX = "x-wren-variable-"
X_WREN_TIMEZONE = "x-wren-timezone"
X_CACHE_HIT = "X-Cache-Hit"
X_CACHE_CREATE_AT = "X-Cache-Create-At"
X_CACHE_OVERRIDE = "X-Cache-Override"
X_CACHE_OVERRIDE_AT = "X-Cache-Override-At"


# Rebuild model to validate the dto is correct via validation of the pydantic
Expand Down
20 changes: 11 additions & 9 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis.expr.schema as sch
import ibis.formats
import pandas as pd
import pyarrow as pa
import sqlglot.expressions as sge
from duckdb import HTTPException, IOException
from google.cloud import bigquery
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
else:
self._connector = SimpleConnector(data_source, connection_info)

def query(self, sql: str, limit: int) -> pd.DataFrame:
def query(self, sql: str, limit: int) -> pa.Table:
return self._connector.query(sql, limit)

def dry_run(self, sql: str) -> None:
Expand All @@ -75,8 +76,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
self.connection = self.data_source.get_connection(connection_info)

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int) -> pd.DataFrame:
return self.connection.sql(sql).limit(limit).to_pandas()
def query(self, sql: str, limit: int) -> pa.Table:
return self.connection.sql(sql).limit(limit).to_pyarrow()

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> None:
Expand Down Expand Up @@ -118,7 +119,7 @@ def __init__(self, connection_info: ConnectionInfo):
def query(self, sql: str, limit: int) -> pd.DataFrame:
# Canner enterprise does not support `CREATE TEMPORARY VIEW` for getting schema
schema = self._get_schema(sql)
return self.connection.sql(sql, schema=schema).limit(limit).to_pandas()
return self.connection.sql(sql, schema=schema).limit(limit).to_pyarrow()

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> Any:
Expand Down Expand Up @@ -146,7 +147,7 @@ def __init__(self, connection_info: ConnectionInfo):
super().__init__(DataSource.bigquery, connection_info)
self.connection_info = connection_info

def query(self, sql: str, limit: int) -> pd.DataFrame:
def query(self, sql: str, limit: int) -> pa.Table:
try:
return super().query(sql, limit)
except ValueError as e:
Expand Down Expand Up @@ -200,9 +201,9 @@ def __init__(self, connection_info: ConnectionInfo):
init_duckdb_gcs(self.connection, connection_info)

@tracer.start_as_current_span("duckdb_query", kind=trace.SpanKind.INTERNAL)
def query(self, sql: str, limit: int) -> pd.DataFrame:
def query(self, sql: str, limit: int) -> pa.Table:
try:
return self.connection.execute(sql).fetch_df().head(limit)
return self.connection.execute(sql).fetch_arrow_table().slice(length=limit)
except IOException as e:
raise UnprocessableEntityError(f"Failed to execute query: {e!s}")
except HTTPException as e:
Expand Down Expand Up @@ -244,12 +245,13 @@ def __init__(self, connection_info: RedshiftConnectionUnion):
raise ValueError("Invalid Redshift connection_info type")

@tracer.start_as_current_span("connector_query", kind=trace.SpanKind.CLIENT)
def query(self, sql: str, limit: int) -> pd.DataFrame:
def query(self, sql: str, limit: int) -> pa.Table:
with closing(self.connection.cursor()) as cursor:
cursor.execute(sql)
cols = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
return pd.DataFrame(rows, columns=cols).head(limit)
df = pd.DataFrame(rows, columns=cols).head(limit)
return pa.Table.from_pandas(df)

@tracer.start_as_current_span("connector_dry_run", kind=trace.SpanKind.CLIENT)
def dry_run(self, sql: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/model/metadata/object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def get_version(self):
def _get_connection(self):
conn = duckdb.connect()
init_duckdb_gcs(conn, self.connection_info)
logger.debug("Initialized duckdb minio")
logger.debug("Initialized duckdb gcs")
return conn

def _get_dal_operator(self):
Expand Down
4 changes: 2 additions & 2 deletions ibis-server/app/model/metadata/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_table_list(self) -> list[Table]:
# we reuse the connector.query method to execute the SQL
# so we have to give a limit to avoid too many rows
# the default limit for tables metadata is 500, which is a sensible limit
response = self.connector.query(sql, limit=500).to_dict(orient="records")
response = self.connector.query(sql, limit=500).to_pylist()

unique_tables = {}
for row in response:
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_constraints(self) -> list[Constraint]:
# we reuse the connector.query method to execute the SQL
# so we have to give a limit to avoid too many rows
# the default limit for constraints metadata is 500, which is a sensible limit
response = self.connector.query(sql, limit=500).to_dict(orient="records")
response = self.connector.query(sql, limit=500).to_pylist()
constraints = []
for row in response:
constraints.append(
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/model/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def format_result(result):
)
try:
rewritten_sql = await self.rewriter.rewrite(sql)
result = self.connector.query(rewritten_sql, limit=1)
result = self.connector.query(rewritten_sql, limit=1).to_pandas()
if not result.get("result").get(0):
raise ValidationError(
f"Relationship {relationship_name} is not valid: {format_result(result)}"
Expand Down
27 changes: 18 additions & 9 deletions ibis-server/app/query_cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import ibis
import opendal
import pyarrow as pa
from duckdb import DuckDBPyConnection, connect
from loguru import logger
from opentelemetry import trace

Expand All @@ -24,10 +26,9 @@ def get(self, data_source: str, sql: str, info) -> Optional[Any]:
# Check if cache file exists
if op.exists(cache_file_name):
try:
logger.info(f"\nReading query cache {cache_file_name}\n")
cache = ibis.read_parquet(full_path)
df = cache.execute()
logger.info("\nquery cache to dataframe\n")
logger.info(f"Reading query cache {cache_file_name}")
df = ibis.read_parquet(full_path).to_pyarrow()
logger.info("query cache to dataframe")
return df
except Exception as e:
logger.debug(f"Failed to read query cache {e}")
Expand All @@ -36,19 +37,24 @@ def get(self, data_source: str, sql: str, info) -> Optional[Any]:
return None

@tracer.start_as_current_span("set_cache", kind=trace.SpanKind.INTERNAL)
def set(self, data_source: str, sql: str, result: Any, info) -> None:
def set(
self,
data_source: str,
sql: str,
result: pa.Table,
info,
) -> None:
cache_key = self._generate_cache_key(data_source, sql, info)
cache_file_name = self._set_cache_file_name(cache_key)
op = self._get_dal_operator()
full_path = self._get_full_path(cache_file_name)

try:
# Open the cache file for writing (created if it doesn't exist)
with op.open(cache_file_name, mode="wb") as file:
cache = ibis.memtable(result)
logger.info(f"\nWriting query cache to {cache_file_name}\n")
con = self._get_duckdb_connection()
arrow_table = con.from_arrow(result)
if file.writable():
cache.to_parquet(full_path)
arrow_table.write_parquet(full_path)
except Exception as e:
logger.debug(f"Failed to write query cache: {e}")
return
Expand Down Expand Up @@ -103,3 +109,6 @@ def _get_full_path(self, path: str) -> str:
def _get_dal_operator(self) -> Any:
# Default implementation using local filesystem
return opendal.Operator("fs", root=self.root)

def _get_duckdb_connection(self) -> DuckDBPyConnection:
return connect()
38 changes: 27 additions & 11 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from opentelemetry import trace
from starlette.datastructures import Headers

from app.dependencies import get_wren_headers, verify_query_dto
from app.dependencies import (
X_CACHE_CREATE_AT,
X_CACHE_HIT,
X_CACHE_OVERRIDE,
X_CACHE_OVERRIDE_AT,
get_wren_headers,
verify_query_dto,
)
from app.mdl.java_engine import JavaEngineConnector
from app.mdl.rewriter import Rewriter
from app.mdl.substitute import ModelSubstitute
Expand All @@ -28,6 +35,7 @@
pushdown_limit,
set_attribute,
to_json,
update_response_headers,
)

router = APIRouter(prefix="/connector", tags=["connector"])
Expand Down Expand Up @@ -107,12 +115,13 @@ async def query(
)
cache_hit = cached_result is not None

cache_headers = {}
# case 1: cache hit read
if cache_enable and cache_hit and not override_cache:
span.add_event("cache hit")
response = ORJSONResponse(to_json(cached_result))
response.headers["X-Cache-Hit"] = "true"
response.headers["X-Cache-Create-At"] = str(
result = cached_result
cache_headers[X_CACHE_HIT] = "true"
cache_headers[X_CACHE_CREATE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
Expand All @@ -126,25 +135,27 @@ async def query(
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
result = connector.query(rewritten_sql, limit=limit)
response = ORJSONResponse(to_json(result))

# headers for all non-hit cases
response.headers["X-Cache-Hit"] = "false"
cache_headers[X_CACHE_HIT] = "false"

match (cache_enable, cache_hit, override_cache):
# case 2 cache hit but override cache
case (True, True, True):
response.headers["X-Cache-Create-At"] = str(
cache_headers[X_CACHE_CREATE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
)
query_cache_manager.set(
data_source, dto.sql, result, dto.connection_info
data_source,
dto.sql,
result,
dto.connection_info,
)

response.headers["X-Cache-Override"] = "true"
response.headers["X-Cache-Override-At"] = str(
cache_headers[X_CACHE_OVERRIDE] = "true"
cache_headers[X_CACHE_OVERRIDE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
Expand All @@ -153,11 +164,16 @@ async def query(
# no matter the cache override or not, we need to create cache
case (True, False, _):
query_cache_manager.set(
data_source, dto.sql, result, dto.connection_info
data_source,
dto.sql,
result,
dto.connection_info,
)
# case 5~8 Other cases (cache is not enabled)
case (False, _, _):
pass
response = ORJSONResponse(to_json(result, headers))
update_response_headers(response, cache_headers)

if is_fallback:
get_fallback_message(
Expand Down
24 changes: 16 additions & 8 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from app.config import get_config
from app.dependencies import (
X_CACHE_CREATE_AT,
X_CACHE_HIT,
X_CACHE_OVERRIDE,
X_CACHE_OVERRIDE_AT,
X_WREN_FALLBACK_DISABLE,
exist_wren_variables_header,
get_wren_headers,
Expand Down Expand Up @@ -36,6 +40,7 @@
safe_strtobool,
set_attribute,
to_json,
update_response_headers,
)

router = APIRouter(prefix="/connector", tags=["connector"])
Expand Down Expand Up @@ -98,12 +103,14 @@ async def query(
data_source, dto.sql, dto.connection_info
)
cache_hit = cached_result is not None

cache_headers = {}
# case 1: cache hit read
if cache_enable and cache_hit and not override_cache:
span.add_event("cache hit")
response = ORJSONResponse(to_json(cached_result))
response.headers["X-Cache-Hit"] = "true"
response.headers["X-Cache-Create-At"] = str(
result = cached_result
cache_headers[X_CACHE_HIT] = "true"
cache_headers[X_CACHE_CREATE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
Expand All @@ -119,15 +126,14 @@ async def query(
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
result = connector.query(rewritten_sql, limit=limit)
response = ORJSONResponse(to_json(result))

# headers for all non-hit cases
response.headers["X-Cache-Hit"] = "false"
cache_headers[X_CACHE_HIT] = "false"

match (cache_enable, cache_hit, override_cache):
# case 2: override existing cache
case (True, True, True):
response.headers["X-Cache-Create-At"] = str(
cache_headers[X_CACHE_CREATE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
Expand All @@ -136,8 +142,8 @@ async def query(
data_source, dto.sql, result, dto.connection_info
)

response.headers["X-Cache-Override"] = "true"
response.headers["X-Cache-Override-At"] = str(
cache_headers[X_CACHE_OVERRIDE] = "true"
cache_headers[X_CACHE_OVERRIDE_AT] = str(
query_cache_manager.get_cache_file_timestamp(
data_source, dto.sql, dto.connection_info
)
Expand All @@ -152,6 +158,8 @@ async def query(
case (False, _, _):
pass

response = ORJSONResponse(to_json(result, headers))
update_response_headers(response, cache_headers)
return response
except Exception as e:
is_fallback_disable = bool(
Expand Down
Loading
Loading