From a541d2f5f96bdbf45e45e41a306fcb67c2e094a8 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 2 Apr 2025 17:56:31 +0800 Subject: [PATCH 1/2] simplify the function info --- wren-core-py/src/context.rs | 71 ++++++++++-------------- wren-core-py/tests/test_modeling_core.py | 14 ++--- 2 files changed, 35 insertions(+), 50 deletions(-) diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 45cfc0023..a91d39e37 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -26,9 +26,8 @@ use std::str::FromStr; use std::sync::Arc; use std::vec; use tokio::runtime::Runtime; -use wren_core::array::{AsArray, GenericByteArray}; +use wren_core::array::AsArray; use wren_core::ast::{visit_statements_mut, Expr, Statement, Value}; -use wren_core::datatypes::GenericStringType; use wren_core::dialect::GenericDialect; use wren_core::mdl::context::create_ctx_with_mdl; use wren_core::mdl::function::{ @@ -235,34 +234,21 @@ impl PySessionContext { } } + /// Get the registered functions in the session context. + /// Only return `name`, `function_type`, and `description`. + /// The `name` is the name of the function. + /// The `function_type` is the type of the function. (e.g. scalar, aggregate, window) + /// The `description` is the description of the function. async fn get_regietered_functions( ctx: &wren_core::SessionContext, - ) -> PyResult> { + ) -> PyResult> { let sql = r#" - WITH inputs AS ( - SELECT - r.specific_name, - r.data_type as return_type, - pi.rid, - array_agg(pi.parameter_name order by pi.ordinal_position) as param_names, - array_agg(pi.data_type order by pi.ordinal_position) as param_types - FROM - information_schema.routines r - JOIN - information_schema.parameters pi ON r.specific_name = pi.specific_name AND pi.parameter_mode = 'IN' - GROUP BY 1, 2, 3 - ) - SELECT + SELECT DISTINCT r.routine_name as name, - i.param_names, - i.param_types, - r.data_type as return_type, r.function_type, r.description FROM information_schema.routines r - LEFT JOIN - inputs i ON r.specific_name = i.specific_name "#; let batches = ctx .sql(sql) @@ -275,27 +261,16 @@ impl PySessionContext { for batch in batches { let name_array = batch.column(0).as_string::(); - let param_names_array = batch.column(1).as_list::(); - let param_types_array = batch.column(2).as_list::(); - let return_type_array = batch.column(3).as_string::(); - let function_type_array = batch.column(4).as_string::(); - let description_array = batch.column(5).as_string::(); + let function_type_array = batch.column(1).as_string::(); + let description_array = batch.column(2).as_string::(); for row in 0..batch.num_rows() { let name = name_array.value(row).to_string(); - let param_names = - Self::to_string_vec(param_names_array.value(row).as_string::()); - let param_types = - Self::to_string_vec(param_types_array.value(row).as_string::()); - let return_type = return_type_array.value(row).to_string(); let description = description_array.value(row).to_string(); let function_type = function_type_array.value(row).to_string(); - functions.push(RemoteFunction { + functions.push(RemoteFunctionDto { name, - param_names: Some(param_names), - param_types: Some(param_types), - return_type, description: Some(description), function_type: FunctionType::from_str(&function_type).unwrap(), }); @@ -303,13 +278,23 @@ impl PySessionContext { } Ok(functions) } +} + +struct RemoteFunctionDto { + name: String, + function_type: FunctionType, + description: Option, +} - fn to_string_vec( - array: &GenericByteArray>, - ) -> Vec> { - array - .iter() - .map(|s| s.map(|s| s.to_string())) - .collect::>>() +impl From for PyRemoteFunction { + fn from(remote_function: RemoteFunctionDto) -> Self { + Self { + function_type: remote_function.function_type.to_string(), + name: remote_function.name, + return_type: None, + param_names: None, + param_types: None, + description: remote_function.description, + } } } diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index c2ae0b260..49f5dc2c8 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -106,7 +106,7 @@ def test_read_function_list(): path = "tests/functions.csv" session_context = SessionContext(manifest_str, path) functions = session_context.get_available_functions() - assert len(functions) == 25948 + assert len(functions) == 283 rewritten_sql = session_context.transform_sql( "SELECT add_two(c_custkey, c_custkey) FROM my_catalog.my_schema.customer" @@ -118,7 +118,7 @@ def test_read_function_list(): session_context = SessionContext(manifest_str, None) functions = session_context.get_available_functions() - assert len(functions) == 25941 + assert len(functions) == 276 def test_get_available_functions(): @@ -128,9 +128,9 @@ def test_get_available_functions(): assert add_two.name == "add_two" assert add_two.function_type == "scalar" assert add_two.description == "Adds two numbers together." - assert add_two.return_type == "Int32" - assert add_two.param_names == "f1,f2" - assert add_two.param_types == "Int32,Int32" + assert add_two.return_type is None + assert add_two.param_names is None + assert add_two.param_types is None max_if = next(f for f in functions if f.name == "max_if") assert max_if.name == "max_if" @@ -142,9 +142,9 @@ def test_get_available_functions(): assert func.name == "add_custom" assert func.function_type == "scalar" assert func.description == "Adds two numbers together." - assert func.return_type == "Int32" + assert func.return_type is None assert func.param_names is None - assert func.param_types == "Int32,Int32" + assert func.param_types is None func = next(f for f in functions if f.name == "test_same_as_input_array") assert func.name == "test_same_as_input_array" From 29cd0dda689e34dc06c85d4f91adf8c636d2df82 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 2 Apr 2025 18:11:19 +0800 Subject: [PATCH 2/2] update function test --- ibis-server/tests/conftest.py | 2 +- .../routers/v3/connector/bigquery/test_functions.py | 9 ++++----- .../routers/v3/connector/local_file/test_functions.py | 6 +++--- .../tests/routers/v3/connector/mysql/test_functions.py | 4 ++-- .../routers/v3/connector/postgres/test_functions.py | 4 ++-- .../tests/routers/v3/connector/trino/test_functions.py | 4 ++-- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/ibis-server/tests/conftest.py b/ibis-server/tests/conftest.py index ece69a1f0..57c030980 100644 --- a/ibis-server/tests/conftest.py +++ b/ibis-server/tests/conftest.py @@ -11,7 +11,7 @@ def file_path(path: str) -> str: return os.path.join(os.path.dirname(__file__), path) -DATAFUSION_FUNCTION_COUNT = 25941 +DATAFUSION_FUNCTION_COUNT = 276 @pytest.fixture(scope="session") diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py index 320c6dcf9..16b99bbfa 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py @@ -48,8 +48,7 @@ async def test_function_list(client): assert len(result) == DATAFUSION_FUNCTION_COUNT + 34 the_func = next( filter( - lambda x: x["name"] == "string_agg" - and x["param_types"] == "LargeUtf8,LargeUtf8", + lambda x: x["name"] == "string_agg", result, ) ) @@ -57,9 +56,9 @@ async def test_function_list(client): "name": "string_agg", "description": "Concatenates the values of string expressions and places separator values between them.", "function_type": "aggregate", - "param_names": "expression,delimiter", - "param_types": "LargeUtf8,LargeUtf8", - "return_type": "LargeUtf8", + "param_names": None, + "param_types": None, + "return_type": None, } config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/local_file/test_functions.py b/ibis-server/tests/routers/v3/connector/local_file/test_functions.py index b84f9201d..735b9a122 100644 --- a/ibis-server/tests/routers/v3/connector/local_file/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/local_file/test_functions.py @@ -60,9 +60,9 @@ async def test_function_list(client): "name": "regexp_escape", "description": "Escapes all potentially meaningful regexp characters in the input string", "function_type": "scalar", - "param_names": "string", - "param_types": "Utf8", - "return_type": "Utf8", + "param_names": None, + "param_types": None, + "return_type": None, } config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/mysql/test_functions.py b/ibis-server/tests/routers/v3/connector/mysql/test_functions.py index c4c36b67f..e219a1488 100644 --- a/ibis-server/tests/routers/v3/connector/mysql/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/mysql/test_functions.py @@ -61,8 +61,8 @@ async def test_function_list(client): "description": "Synonym for LOWER()", "function_type": "scalar", "param_names": None, - "param_types": "Utf8", - "return_type": "Utf8", + "param_types": None, + "return_type": None, } config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py index 7c42f3c0e..83d7d98ab 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py @@ -62,8 +62,8 @@ async def test_function_list(client): "description": "Get subfield from date/time", "function_type": "scalar", "param_names": None, - "param_types": "Utf8,Timestamp(Nanosecond, None)", - "return_type": "Decimal128(38, 10)", + "param_types": None, + "return_type": None, } config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/trino/test_functions.py b/ibis-server/tests/routers/v3/connector/trino/test_functions.py index c7140be20..a04a72eb1 100644 --- a/ibis-server/tests/routers/v3/connector/trino/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/trino/test_functions.py @@ -60,8 +60,8 @@ async def test_function_list(client): "description": "Converts binary to base64", "function_type": "scalar", "param_names": None, - "param_types": "Binary", - "return_type": "Utf8", + "param_types": None, + "return_type": None, } config.set_remote_function_list_path(None)