Skip to content

Commit

Permalink
Add dummy UDFs for sqlintegration
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Feb 13, 2024
1 parent dfaa978 commit a355d4a
Showing 1 changed file with 61 additions and 4 deletions.
65 changes: 61 additions & 4 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::{sync::Arc, vec};
Expand All @@ -29,7 +30,8 @@ use datafusion_common::{
use datafusion_common::{plan_err, ParamValues};
use datafusion_expr::{
logical_plan::{LogicalPlan, Prepare},
AggregateUDF, ScalarUDF, TableSource, WindowUDF,
AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource,
Volatility, WindowUDF,
};
use datafusion_sql::{
parser::DFParser,
Expand Down Expand Up @@ -2671,13 +2673,62 @@ fn logical_plan_with_dialect_and_options(
dialect: &dyn Dialect,
options: ParserOptions,
) -> Result<LogicalPlan> {
let context = MockContextProvider::default();
let context = MockContextProvider::default().with_udf(make_udf(
"nullif",
vec![DataType::Int32, DataType::Int32],
DataType::Int32,
));

let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
planner.statement_to_plan(ast.pop_front().unwrap())
}

fn make_udf(name: &'static str, args: Vec<DataType>, return_type: DataType) -> ScalarUDF {
ScalarUDF::new_from_impl(DummyUDF::new(name, args, return_type))
}

/// Mocked UDF
#[derive(Debug)]
struct DummyUDF {
name: &'static str,
signature: Signature,
return_type: DataType,
}

impl DummyUDF {
fn new(name: &'static str, args: Vec<DataType>, return_type: DataType) -> Self {
Self {
name,
signature: Signature::exact(args, Volatility::Immutable),
return_type,
}
}
}

impl ScalarUDFImpl for DummyUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}

/// Create logical plan, write with formatter, compare to expected output
fn quick_test(sql: &str, expected: &str) {
let plan = logical_plan(sql).unwrap();
Expand Down Expand Up @@ -2724,13 +2775,19 @@ fn prepare_stmt_replace_params_quick_test(
#[derive(Default)]
struct MockContextProvider {
options: ConfigOptions,
udfs: HashMap<String, Arc<ScalarUDF>>,
udafs: HashMap<String, Arc<AggregateUDF>>,
}

impl MockContextProvider {
fn options_mut(&mut self) -> &mut ConfigOptions {
&mut self.options
}

fn with_udf(mut self, udf: ScalarUDF) -> Self {
self.udfs.insert(udf.name().to_string(), Arc::new(udf));
self
}
}

impl ContextProvider for MockContextProvider {
Expand Down Expand Up @@ -2823,8 +2880,8 @@ impl ContextProvider for MockContextProvider {
}
}

fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
None
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.udfs.get(name).map(Arc::clone)
}

fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
Expand Down

0 comments on commit a355d4a

Please sign in to comment.