Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def update(self, diagnose: bool):
def get_remote_function_list_path(self, data_source: str) -> str:
if not self.remote_function_list_path:
return None

# The function list has been defined by Wren Core
if data_source in {"bigquery"}:
return None

if data_source in {"local_file", "s3_file", "minio_file", "gcs_file"}:
data_source = "duckdb"
base_path = os.path.normpath(self.remote_function_list_path)
Expand Down
9 changes: 7 additions & 2 deletions ibis-server/app/mdl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

@cache
def get_session_context(
manifest_str: str | None, function_path: str, properties: frozenset | None = None
manifest_str: str | None,
function_path: str,
properties: frozenset | None = None,
data_source: str | None = None,
) -> wren_core.SessionContext:
return wren_core.SessionContext(manifest_str, function_path, properties)
return wren_core.SessionContext(
manifest_str, function_path, properties, data_source
)


def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor:
Expand Down
12 changes: 9 additions & 3 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(
self.properties = properties
if experiment:
function_path = get_config().get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(function_path)
self._rewriter = EmbeddedEngineRewriter(
function_path=function_path, data_source=data_source
)
else:
self._rewriter = ExternalEngineRewriter(java_engine_connector)

Expand Down Expand Up @@ -130,7 +132,8 @@ def handle_extract_exception(e: Exception):


class EmbeddedEngineRewriter:
def __init__(self, function_path: str):
def __init__(self, function_path: str, data_source: DataSource = None):
self.data_source = data_source
Comment thread
goldmedal marked this conversation as resolved.
self.function_path = function_path

@tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL)
Expand All @@ -140,7 +143,10 @@ async def rewrite(
try:
processed_properties = self.get_session_properties(properties)
session_context = get_session_context(
manifest_str, self.function_path, processed_properties
manifest_str,
self.function_path,
processed_properties,
self.data_source.name if self.data_source else None,
)
return await to_thread.run_sync(
session_context.transform_sql,
Expand Down
4 changes: 2 additions & 2 deletions ibis-server/app/model/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, Field
from starlette.status import (
HTTP_404_NOT_FOUND,
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_422_UNPROCESSABLE_CONTENT,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_501_NOT_IMPLEMENTED,
HTTP_502_BAD_GATEWAY,
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_http_status_code(self) -> int:
return HTTP_504_GATEWAY_TIMEOUT
case e:
if e.value < 100:
return HTTP_422_UNPROCESSABLE_ENTITY
return HTTP_422_UNPROCESSABLE_CONTENT
return HTTP_500_INTERNAL_SERVER_ERROR


Expand Down
44 changes: 0 additions & 44 deletions ibis-server/resources/function_list/bigquery.csv

This file was deleted.

2 changes: 1 addition & 1 deletion ibis-server/tools/query_local_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
print("### Session Properties ###")
for key, value in properties:
print(f"# {key}: {value}")
session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties)
session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv", properties, data_source)
planned_sql = session_context.transform_sql(sql)
print("# Planned SQL:\n", planned_sql)

Expand Down
63 changes: 61 additions & 2 deletions wren-core-base/src/mdl/manifest.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::error::Error;
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand All @@ -17,6 +18,7 @@
* under the License.
*/
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;

#[cfg(not(feature = "python-binding"))]
Expand Down Expand Up @@ -99,6 +101,32 @@ mod manifest_impl {

pub use crate::mdl::manifest::manifest_impl::*;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedDataSourceError {
pub message: String,
}

impl ParsedDataSourceError {
pub fn new(msg: &str) -> ParsedDataSourceError {
ParsedDataSourceError {
message: msg.to_string(),
}
}
}

impl Display for ParsedDataSourceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ParsedDataSourceError: {}", self.message)
}
}

impl Error for ParsedDataSourceError {
#[allow(deprecated)]
fn description(&self) -> &str {
&self.message
}
}

impl Display for DataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -124,6 +152,37 @@ impl Display for DataSource {
}
}

impl FromStr for DataSource {
type Err = ParsedDataSourceError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_uppercase().as_str() {
"BIGQUERY" => Ok(DataSource::BigQuery),
"CLICKHOUSE" => Ok(DataSource::Clickhouse),
"CANNER" => Ok(DataSource::Canner),
"TRINO" => Ok(DataSource::Trino),
"MSSQL" => Ok(DataSource::MSSQL),
"MYSQL" => Ok(DataSource::MySQL),
"POSTGRES" => Ok(DataSource::Postgres),
"SNOWFLAKE" => Ok(DataSource::Snowflake),
"DATAFUSION" => Ok(DataSource::Datafusion),
"DUCKDB" => Ok(DataSource::DuckDB),
"LOCAL_FILE" => Ok(DataSource::LocalFile),
"S3_FILE" => Ok(DataSource::S3File),
"GCS_FILE" => Ok(DataSource::GcsFile),
"MINIO_FILE" => Ok(DataSource::MinioFile),
"ORACLE" => Ok(DataSource::Oracle),
"ATHENA" => Ok(DataSource::Athena),
"REDSHIFT" => Ok(DataSource::Redshift),
"DATABRICKS" => Ok(DataSource::Databricks),
_ => Err(ParsedDataSourceError::new(&format!(
"Unknown data source: {}",
s
))),
}
}
}

mod table_reference {
use serde::{self, Deserialize, Deserializer, Serialize, Serializer};

Expand Down Expand Up @@ -260,7 +319,7 @@ impl Model {
self.columns
.iter()
.filter(|c| c.relationship.is_none())
.map(|c| Arc::clone(&c))
.map(Arc::clone)
.collect()
}
}
Expand All @@ -286,7 +345,7 @@ impl Model {
self.columns
.iter()
.find(|c| c.name == column_name)
.map(|c| Arc::clone(&c))
.map(Arc::clone)
}

/// Return the primary key of the model
Expand Down
113 changes: 80 additions & 33 deletions wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use wren_core::mdl::function::{
use wren_core::{
mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, SessionConfig, WindowUDF,
};
use wren_core_base::mdl::DataSource;

/// The Python wrapper for the Wren Core session context.
#[pyclass(name = "SessionContext")]
Expand Down Expand Up @@ -77,45 +78,27 @@ impl PySessionContext {
/// if `mdl_base64` is provided, the session context will be created with the given MDL. Otherwise, an empty MDL will be created.
/// if `remote_functions_path` is provided, the session context will be created with the remote functions defined in the CSV file.
#[new]
#[pyo3(signature = (mdl_base64=None, remote_functions_path=None, properties=None))]
#[pyo3(signature = (mdl_base64=None, remote_functions_path=None, properties=None, data_source=None))]
pub fn new(
mdl_base64: Option<&str>,
remote_functions_path: Option<&str>,
properties: Option<Py<PyAny>>,
data_source: Option<&str>,
) -> PyResult<Self> {
let remote_functions = Self::read_remote_function_list(remote_functions_path)
.map_err(CoreError::from)?;
let remote_functions: Vec<RemoteFunction> = remote_functions
.into_iter()
.map(|f| f.into())
.collect::<Vec<_>>();

let config = SessionConfig::default().with_information_schema(true);
let ctx = wren_core::mdl::create_wren_ctx(Some(config));
let runtime = Runtime::new().map_err(CoreError::from)?;

let registered_functions = runtime
.block_on(Self::get_registered_functions(&ctx))
.map(|functions| {
functions
.into_iter()
.map(|f| f.name)
.collect::<std::collections::HashSet<String>>()
})
.map_err(CoreError::from)?;

remote_functions
.into_iter()
.try_for_each(|remote_function| {
debug!("Registering remote function: {:?}", remote_function);
// TODO: check not only the name but also the return type and the parameter types
if !registered_functions.contains(&remote_function.name) {
Self::register_remote_function(&ctx, remote_function)?;
}
Ok::<(), CoreError>(())
})?;

let Some(mdl_base64) = mdl_base64 else {
let data_source = data_source
.map(|ds| DataSource::from_str(ds).map_err(CoreError::from))
.transpose()?;
let config = SessionConfig::default().with_information_schema(true);
let ctx = wren_core::mdl::create_wren_ctx(Some(config), data_source.as_ref());
Self::register_function_by_data_source(
data_source.as_ref(),
remote_functions_path,
&runtime,
&ctx,
)?;
return Ok(Self {
ctx: ctx.clone(),
exec_ctx: ctx,
Expand All @@ -125,7 +108,30 @@ impl PySessionContext {
});
};

Python::attach(|py| {
let manifest = to_manifest(mdl_base64)?;

// If the manifest has a data source, use it.
// Otherwise, if the data_source parameter is provided, use it.
// Otherwise, use None.
let data_source = if let Some(ds) = &manifest.data_source {
Some(*ds)
} else if let Some(ds_str) = data_source {
Some(DataSource::from_str(ds_str).map_err(CoreError::from)?)
} else {
None
};

let config = SessionConfig::default().with_information_schema(true);
let ctx = wren_core::mdl::create_wren_ctx(Some(config), data_source.as_ref());

Self::register_function_by_data_source(
data_source.as_ref(),
remote_functions_path,
&runtime,
&ctx,
)?;

Python::attach(|py: Python<'_>| {
let properties_map = if let Some(obj) = properties {
let obj = obj.as_ref();
if obj.is_none(py) {
Expand Down Expand Up @@ -159,7 +165,6 @@ impl PySessionContext {
} else {
HashMap::new()
};
let manifest = to_manifest(mdl_base64)?;
let properties_ref = Arc::new(properties_map);
match AnalyzedWrenMDL::analyze(
manifest,
Expand Down Expand Up @@ -360,6 +365,48 @@ impl PySessionContext {
}
Ok(functions)
}

fn register_function_by_data_source(
data_source: Option<&DataSource>,
remote_functions_path: Option<&str>,
runtime: &Runtime,
ctx: &wren_core::SessionContext,
) -> PyResult<()> {
match data_source {
Some(DataSource::BigQuery) => {}
_ => {
let remote_functions =
Self::read_remote_function_list(remote_functions_path)
.map_err(CoreError::from)?;
let remote_functions: Vec<RemoteFunction> = remote_functions
.into_iter()
.map(|f| f.into())
.collect::<Vec<_>>();

let registered_functions = runtime
.block_on(Self::get_registered_functions(ctx))
.map(|functions| {
functions
.into_iter()
.map(|f| f.name)
.collect::<std::collections::HashSet<String>>()
})
.map_err(CoreError::from)?;

remote_functions
.into_iter()
.try_for_each(|remote_function| {
debug!("Registering remote function: {:?}", remote_function);
// TODO: check not only the name but also the return type and the parameter types
if !registered_functions.contains(&remote_function.name) {
Self::register_remote_function(ctx, remote_function)?;
}
Ok::<(), CoreError>(())
})?;
}
}
Ok(())
}
}

struct RemoteFunctionDto {
Expand Down
Loading
Loading