Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def update(self, diagnose: bool):
def get_remote_function_list_path(self, data_source: str) -> str:
if not self.remote_function_list_path:
return None

# The function list has been defined by Wren Core
if data_source in {"bigquery"}:
return None

if data_source in {"local_file", "s3_file", "minio_file", "gcs_file"}:
data_source = "duckdb"
base_path = os.path.normpath(self.remote_function_list_path)
Expand Down
9 changes: 7 additions & 2 deletions ibis-server/app/mdl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

@cache
def get_session_context(
manifest_str: str | None, function_path: str, properties: frozenset | None = None
manifest_str: str | None,
function_path: str,
properties: frozenset | None = None,
data_source: str | None = None,
) -> wren_core.SessionContext:
return wren_core.SessionContext(manifest_str, function_path, properties)
return wren_core.SessionContext(
manifest_str, function_path, properties, data_source
)


def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor:
Expand Down
22 changes: 17 additions & 5 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(
self.properties = properties
if experiment:
function_path = get_config().get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(function_path)
self._rewriter = EmbeddedEngineRewriter(
function_path=function_path, data_source=data_source
)
else:
self._rewriter = ExternalEngineRewriter(java_engine_connector)

Expand Down Expand Up @@ -130,7 +132,8 @@ def handle_extract_exception(e: Exception):


class EmbeddedEngineRewriter:
def __init__(self, function_path: str):
def __init__(self, function_path: str, data_source: DataSource = None):
self.data_source = data_source
self.function_path = function_path

@tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL)
Expand All @@ -140,7 +143,10 @@ async def rewrite(
try:
processed_properties = self.get_session_properties(properties)
session_context = get_session_context(
manifest_str, self.function_path, processed_properties
manifest_str,
self.function_path,
processed_properties,
self.data_source.name if self.data_source else None,
)
return await to_thread.run_sync(
session_context.transform_sql,
Expand All @@ -151,12 +157,18 @@ async def rewrite(

@tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL)
def rewrite_sync(
self, manifest_str: str, sql: str, properties: dict | None = None
self,
manifest_str: str,
sql: str,
properties: dict | None = None,
) -> str:
try:
processed_properties = self.get_session_properties(properties)
session_context = get_session_context(
manifest_str, self.function_path, processed_properties
manifest_str,
self.function_path,
processed_properties,
self.data_source.name if self.data_source else None,
)
return session_context.transform_sql(sql)
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions ibis-server/app/model/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, Field
from starlette.status import (
HTTP_404_NOT_FOUND,
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_422_UNPROCESSABLE_CONTENT,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_501_NOT_IMPLEMENTED,
HTTP_502_BAD_GATEWAY,
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_http_status_code(self) -> int:
return HTTP_504_GATEWAY_TIMEOUT
case e:
if e.value < 100:
return HTTP_422_UNPROCESSABLE_ENTITY
return HTTP_422_UNPROCESSABLE_CONTENT
return HTTP_500_INTERNAL_SERVER_ERROR


Expand Down
44 changes: 0 additions & 44 deletions ibis-server/resources/function_list/bigquery.csv

This file was deleted.

2 changes: 1 addition & 1 deletion ibis-server/tools/query_local_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
print("### Session Properties ###")
for key, value in properties:
print(f"# {key}: {value}")
session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties)
session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties, data_source)
planned_sql = session_context.transform_sql(sql)
print("# Planned SQL:\n", planned_sql)

Expand Down
4 changes: 3 additions & 1 deletion ibis-server/wren/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def plan(self, input_sql):
)

self.planned_sql = self.context.rewriter.rewrite_sync(
self.manifest, self.wren_sql, self.properties
self.manifest,
self.wren_sql,
self.properties,
)

read = self._get_read_dialect()
Expand Down
63 changes: 61 additions & 2 deletions wren-core-base/src/mdl/manifest.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::error::Error;
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand All @@ -17,6 +18,7 @@
* under the License.
*/
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;

#[cfg(not(feature = "python-binding"))]
Expand Down Expand Up @@ -99,6 +101,32 @@ mod manifest_impl {

pub use crate::mdl::manifest::manifest_impl::*;

/// Error carrying a message that explains why a data source name could not
/// be parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedDataSourceError {
    /// Human-readable explanation of the failure.
    pub message: String,
}

impl ParsedDataSourceError {
    /// Creates an error that owns a copy of `msg`.
    pub fn new(msg: &str) -> Self {
        let message = String::from(msg);
        Self { message }
    }
}

impl Display for ParsedDataSourceError {
    /// Writes the error as `ParsedDataSourceError: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "ParsedDataSourceError: {}", message)
    }
}

impl Error for ParsedDataSourceError {
    // `Error::description` is deprecated in std; it is still overridden here so
    // that it returns the raw message (without the `Display` prefix).
    #[allow(deprecated)]
    fn description(&self) -> &str {
        self.message.as_str()
    }
}

impl Display for DataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -124,6 +152,37 @@ impl Display for DataSource {
}
}

impl FromStr for DataSource {
    type Err = ParsedDataSourceError;

    /// Parses a data source name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns a [`ParsedDataSourceError`] whose message is
    /// `Unknown data source: <input>` when the name is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Uppercase once so the lookup below is case-insensitive.
        let normalized = s.to_uppercase();
        let data_source = match normalized.as_str() {
            "BIGQUERY" => Self::BigQuery,
            "CLICKHOUSE" => Self::Clickhouse,
            "CANNER" => Self::Canner,
            "TRINO" => Self::Trino,
            "MSSQL" => Self::MSSQL,
            "MYSQL" => Self::MySQL,
            "POSTGRES" => Self::Postgres,
            "SNOWFLAKE" => Self::Snowflake,
            "DATAFUSION" => Self::Datafusion,
            "DUCKDB" => Self::DuckDB,
            "LOCAL_FILE" => Self::LocalFile,
            "S3_FILE" => Self::S3File,
            "GCS_FILE" => Self::GcsFile,
            "MINIO_FILE" => Self::MinioFile,
            "ORACLE" => Self::Oracle,
            "ATHENA" => Self::Athena,
            "REDSHIFT" => Self::Redshift,
            "DATABRICKS" => Self::Databricks,
            _ => {
                // Report the caller's original spelling, not the normalized one.
                let message = format!("Unknown data source: {}", s);
                return Err(ParsedDataSourceError::new(&message));
            }
        };
        Ok(data_source)
    }
}

mod table_reference {
use serde::{self, Deserialize, Deserializer, Serialize, Serializer};

Expand Down Expand Up @@ -260,7 +319,7 @@ impl Model {
self.columns
.iter()
.filter(|c| c.relationship.is_none())
.map(|c| Arc::clone(&c))
.map(Arc::clone)
.collect()
}
}
Expand All @@ -286,7 +345,7 @@ impl Model {
self.columns
.iter()
.find(|c| c.name == column_name)
.map(|c| Arc::clone(&c))
.map(Arc::clone)
}

/// Return the primary key of the model
Expand Down
Loading
Loading