diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 7806461bb1ac..2ffac6a775d7 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -30,7 +30,7 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion::assert_batches_eq; -use datafusion_common::DFSchema; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; @@ -161,6 +161,181 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_nullif() -> Result<()> { + let expr = nullif(col("a"), lit("abcDEF")); + + let expected = [ + "+-------------------------------+", + "| nullif(test.a,Utf8(\"abcDEF\")) |", + "+-------------------------------+", + "| |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+-------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_arrow_cast() -> Result<()> { + let expr = arrow_typeof(arrow_cast(col("b"), lit("Float64"))); + + let expected = [ + "+--------------------------------------------------+", + "| arrow_typeof(arrow_cast(test.b,Utf8(\"Float64\"))) |", + "+--------------------------------------------------+", + "| Float64 |", + "| Float64 |", + "| Float64 |", + "| Float64 |", + "+--------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_nvl() -> Result<()> { + let lit_null = lit(ScalarValue::Utf8(None)); + // nvl(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'TURNED_NULL') + let expr = nvl( + when(col("a").eq(lit("abcDEF")), lit_null) + .otherwise(col("a")) + .unwrap(), + lit("TURNED_NULL"), + ) + .alias("nvl_expr"); + + let expected = [ + "+-------------+", + "| nvl_expr |", + "+-------------+", + "| TURNED_NULL |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+-------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} +#[tokio::test] +async fn test_nvl2() -> Result<()> { + let lit_null = lit(ScalarValue::Utf8(None)); + // nvl2(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'NON_NUll', 'TURNED_NULL') + let expr = nvl2( + when(col("a").eq(lit("abcDEF")), lit_null) + .otherwise(col("a")) + .unwrap(), + lit("NON_NULL"), + lit("TURNED_NULL"), + ) + .alias("nvl2_expr"); + + let expected = [ + "+-------------+", + "| nvl2_expr |", + "+-------------+", + "| TURNED_NULL |", + "| NON_NULL |", + "| NON_NULL |", + "| NON_NULL |", + "+-------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} +#[tokio::test] +async fn test_fn_arrow_typeof() -> Result<()> { + let expr = arrow_typeof(col("l")); + + let expected = [ + "+------------------------------------------------------------------------------------------------------------------+", + "| arrow_typeof(test.l) |", + "+------------------------------------------------------------------------------------------------------------------+", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "+------------------------------------------------------------------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_struct() -> Result<()> { + let expr = r#struct(vec![col("a"), col("b")]); + + let expected = [ + "+--------------------------+", + "| struct(test.a,test.b) |", + "+--------------------------+", + "| {c0: abcDEF, c1: 1} |", + "| {c0: abc123, c1: 10} |", + "| {c0: CBAdef, c1: 10} |", + "| {c0: 123AbcDef, c1: 100} |", + "+--------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_named_struct() -> Result<()> { + let expr = named_struct(vec![lit("column_a"), col("a"), lit("column_b"), col("b")]); + + let expected = [ + "+---------------------------------------------------------------+", + "| named_struct(Utf8(\"column_a\"),test.a,Utf8(\"column_b\"),test.b) |", + "+---------------------------------------------------------------+", + "| {column_a: abcDEF, column_b: 1} |", + "| {column_a: abc123, column_b: 10} |", + "| {column_a: CBAdef, column_b: 10} |", + "| {column_a: 123AbcDef, column_b: 100} |", + "+---------------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_coalesce() -> Result<()> { + let expr = coalesce(vec![lit(ScalarValue::Utf8(None)), lit("ab")]); + + let expected = [ + "+---------------------------------+", + "| coalesce(Utf8(NULL),Utf8(\"ab\")) |", + "+---------------------------------+", + "| ab |", + "| ab |", + "| ab |", + "| ab |", + "+---------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] async fn test_fn_approx_median() -> Result<()> { let expr = approx_median(col("b")); diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 753134bdfdc2..d60e6017ddcb 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,6 +17,9 @@ //! "core" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod arrow_cast; pub mod arrowtypeof; pub mod coalesce; @@ -39,14 +42,68 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); // Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), - (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."), - (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), - (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), - (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), - (r#struct, args, "Returns a struct with the given arguments"), - (named_struct, args, "Returns a struct with the given names and arguments pairs"), - (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct"), - (coalesce, args, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL") -); +pub mod expr_fn { + use datafusion_expr::Expr; + + /// returns NULL if value1 equals value2; otherwise it returns value1. This + /// can be used to perform the inverse operation of the COALESCE expression + pub fn nullif(arg1: Expr, arg2: Expr) -> Expr { + super::nullif().call(vec![arg1, arg2]) + } + + /// returns value1 cast to the `arrow_type` given the second argument. This + /// can be used to cast to a specific `arrow_type`. + pub fn arrow_cast(arg1: Expr, arg2: Expr) -> Expr { + super::arrow_cast().call(vec![arg1, arg2]) + } + + /// Returns value2 if value1 is NULL; otherwise it returns value1 + pub fn nvl(arg1: Expr, arg2: Expr) -> Expr { + super::nvl().call(vec![arg1, arg2]) + } + + /// Returns value2 if value1 is not NULL; otherwise, it returns value3. + pub fn nvl2(arg1: Expr, arg2: Expr, arg3: Expr) -> Expr { + super::nvl2().call(vec![arg1, arg2, arg3]) + } + + /// Returns the Arrow type of the input expression. + pub fn arrow_typeof(arg1: Expr) -> Expr { + super::arrow_typeof().call(vec![arg1]) + } + + /// Returns a struct with the given arguments + pub fn r#struct(args: Vec) -> Expr { + super::r#struct().call(args) + } + + /// Returns a struct with the given names and arguments pairs + pub fn named_struct(args: Vec) -> Expr { + super::named_struct().call(args) + } + + /// Returns the value of the field with the given name from the struct + pub fn get_field(arg1: Expr, arg2: Expr) -> Expr { + super::get_field().call(vec![arg1, arg2]) + } + + /// Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL + pub fn coalesce(args: Vec) -> Expr { + super::coalesce().call(args) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![ + nullif(), + arrow_cast(), + nvl(), + nvl2(), + arrow_typeof(), + r#struct(), + named_struct(), + get_field(), + coalesce(), + ] +}