diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 20115fead..d0c2b454c 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -443,6 +443,28 @@ def functions( return ORJSONResponse(func_list) +@router.get( + "/{data_source}/function/{function_name}", + description="get the available function list of the specified data source", +) +def function( + data_source: DataSource, + function_name: str, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, +) -> Response: + span_name = f"v3_get_function_{data_source}" + with tracer.start_as_current_span( + name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) + ) as span: + set_attribute(headers, span) + file_path = get_config().get_remote_function_list_path(data_source) + session_context = get_session_context(None, file_path) + func_list = [ + f.to_dict() for f in session_context.get_available_function(function_name) + ] + return ORJSONResponse(func_list) + + @router.post( "/{data_source}/model-substitute", description="get the SQL which table name is substituted", diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py index 2524afe86..ac7adf5ac 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py @@ -70,6 +70,22 @@ async def test_function_list(client): assert len(result) == DATAFUSION_FUNCTION_COUNT +async def test_get_function(client): + response = await client.get(url=f"{base_url}/function/div") + assert response.status_code == 200 + result = response.json() + assert result == [ + { + "name": "div", + "description": "trunc(x/y)", + "function_type": "scalar", + "param_names": None, + "param_types": "Decimal(38, 10),Decimal(38, 10)", + "return_type": "Decimal(38, 10)", + } + ] + + async def test_scalar_function(client, manifest_str: str, connection_info): response = await client.post( url=f"{base_url}/query", diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 02148175e..b041c840e 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -29,8 +29,9 @@ use std::str::FromStr; use std::sync::Arc; use std::vec; use tokio::runtime::Runtime; -use wren_core::array::AsArray; +use wren_core::array::{AsArray, GenericByteArray}; use wren_core::ast::{visit_statements_mut, Expr, Statement, Value, ValueWithSpan}; +use wren_core::datatypes::GenericStringType; use wren_core::dialect::GenericDialect; use wren_core::mdl::context::apply_wren_on_ctx; use wren_core::mdl::function::{ @@ -239,6 +240,20 @@ impl PySessionContext { Ok(registered_functions) } + pub fn get_available_function( + &self, + function_name: &str, + ) -> PyResult> { + let functions = self + .runtime + .block_on(Self::get_registered_function(function_name, &self.exec_ctx)) + .map_err(CoreError::from)? + .into_iter() + .map(|f| PyRemoteFunction::from(f)) + .collect::>(); + Ok(functions) + } + /// Push down the limit to the given SQL. /// If the limit is None, the SQL will be returned as is. /// If the limit is greater than the pushdown limit, the limit will be replaced with the pushdown limit. @@ -407,6 +422,90 @@ impl PySessionContext { } Ok(()) } + + async fn get_registered_function( + function_name: &str, + ctx: &wren_core::SessionContext, + ) -> PyResult> { + let sql = format!( + r#" + WITH inputs AS ( + SELECT + r.specific_name, + r.data_type as return_type, + pi.rid, + array_agg(pi.parameter_name order by pi.ordinal_position) as param_names, + array_agg(pi.data_type order by pi.ordinal_position) as param_types + FROM + information_schema.routines r + JOIN + information_schema.parameters pi ON r.specific_name = pi.specific_name AND pi.parameter_mode = 'IN' + GROUP BY 1, 2, 3 + ) + SELECT + r.routine_name as name, + i.param_names, + i.param_types, + r.data_type as return_type, + r.function_type, + r.description + FROM + information_schema.routines r + LEFT JOIN + inputs i ON r.specific_name = i.specific_name + WHERE + r.routine_name = '{}' + "#, + function_name + ); + let batches = ctx + .sql(&sql) + .await + .map_err(CoreError::from)? + .collect() + .await + .map_err(CoreError::from)?; + let mut functions = vec![]; + + for batch in batches { + let name_array = batch.column(0).as_string::(); + let param_names_array = batch.column(1).as_list::(); + let param_types_array = batch.column(2).as_list::(); + let return_type_array = batch.column(3).as_string::(); + let function_type_array = batch.column(4).as_string::(); + let description_array = batch.column(5).as_string::(); + + for row in 0..batch.num_rows() { + let name = name_array.value(row).to_string(); + let param_names = + Self::to_string_vec(param_names_array.value(row).as_string::()); + let param_types = + Self::to_string_vec(param_types_array.value(row).as_string::()); + let return_type = return_type_array.value(row).to_string(); + let description = description_array.value(row).to_string(); + let function_type = function_type_array.value(row).to_string(); + + functions.push(RemoteFunction { + name, + param_names: Some(param_names), + param_types: Some(param_types), + return_type, + description: Some(description), + function_type: FunctionType::from_str(&function_type).unwrap(), + }); + } + } + Ok(functions) + } + + fn to_string_vec( + array: &GenericByteArray>, + ) -> Vec> { + array + .iter() + .map(|s| s.map(|s| s.to_string())) + .collect::>>() + } } struct RemoteFunctionDto {