Skip to content

Commit

Permalink
Add datafusion_typeof function
Browse files Browse the repository at this point in the history
Add a function, similar to arrow_typeof, for inspecting the logical type
of an expression as seen by DataFusion.
  • Loading branch information
findepi committed Aug 30, 2024
1 parent e603185 commit 06a00f7
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 2 deletions.
31 changes: 31 additions & 0 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,34 @@ async fn test_parameter_invalid_types() -> Result<()> {
);
Ok(())
}

/// Verifies the type name that `datafusion_typeof` yields for a NULL
/// literal, an integer literal and a string literal.
#[tokio::test]
async fn test_datafusion_typeof() -> Result<()> {
    // (SQL expression, expected type name) pairs, checked in order.
    let cases = [
        ("datafusion_typeof(NULL)", "Null"),
        ("datafusion_typeof(42)", "Int64"),
        ("datafusion_typeof('abc')", "Utf8"),
    ];

    for (sql, type_name) in cases {
        assert_scalar(sql, &ScalarValue::from(type_name)).await?;
    }

    Ok(())
}

/// Runs `SELECT <expression> AS t` on a fresh `SessionContext` and asserts
/// that the result is exactly one batch with one column and one row whose
/// value equals `expected`.
///
/// Only `Utf8` result columns compared against a non-null
/// `ScalarValue::Utf8` are handled.
///
/// # Panics
///
/// Panics if the result shape is not 1 batch x 1 column x 1 row, if the
/// column's data type is not `Utf8`, if `expected` is not
/// `ScalarValue::Utf8(Some(_))`, or if the values differ.
async fn assert_scalar(expression: impl Into<&str>, expected: &ScalarValue) -> Result<()> {
    let ctx = SessionContext::new();
    let sql = format!("SELECT {} AS t", expression.into());
    let batches = execute_to_batches(&ctx, &sql).await;

    // The query must produce a single 1x1 result.
    assert_eq!(batches.len(), 1);
    let batch = &batches[0];
    assert_eq!(batch.num_columns(), 1);
    let array = batch.column(0);
    assert_eq!(array.len(), 1);

    // Guard: the produced column has to be a Utf8 array.
    if !matches!(array.data_type(), DataType::Utf8) {
        panic!("unsupported array data type");
    }
    let actual = array.as_string::<i32>().value(0);

    // Guard: the expectation has to be a non-null Utf8 scalar.
    let ScalarValue::Utf8(Some(wanted)) = expected else {
        panic!("unsupported expected scalar type");
    };
    assert_eq!(actual, wanted);

    Ok(())
}
7 changes: 7 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::sync::Arc;
pub mod arrow_cast;
pub mod arrowtypeof;
pub mod coalesce;
pub mod dftypeof;
pub mod expr_ext;
pub mod getfield;
pub mod named_struct;
Expand All @@ -38,6 +39,7 @@ make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);
make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof);
make_udf_function!(dftypeof::DataFusionTypeOfFunc, DFTYPEOF, datafusion_typeof);
make_udf_function!(r#struct::StructFunc, STRUCT, r#struct);
make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct);
make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
Expand Down Expand Up @@ -66,6 +68,10 @@ pub mod expr_fn {
arrow_typeof,
"Returns the Arrow type of the input expression.",
arg1
),(
datafusion_typeof,
"Returns the DataFusion type of the input expression.",
arg1
),(
r#struct,
"Returns a struct with the given arguments",
Expand Down Expand Up @@ -94,6 +100,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
nvl(),
nvl2(),
arrow_typeof(),
datafusion_typeof(),
named_struct(),
// Note: most users invoke `get_field` indirectly via field access
// syntax like `my_struct_col['field_name']`, which results in a call to
Expand Down
15 changes: 13 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarV
use datafusion_expr::expr::{InList, InSubquery, WindowFunction};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
WindowFunctionDefinition,
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator,
Volatility, WindowFunctionDefinition,
};
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
Expand Down Expand Up @@ -625,6 +625,17 @@ impl<'a> ConstEvaluator<'a> {
if let Expr::Literal(s) = expr {
return ConstSimplifyResult::NotSimplified(s);
}
if let Expr::ScalarFunction(ScalarFunction { func, args }) = &expr {
// This function cannot be evaluated at runtime since DataFusion types are elided.
// TODO: the function is provided by core; ideally this logic would be provided by it too.
if func.name() == "datafusion_typeof" && args.len() == 1 {
if let Ok(data_type) = args[0].get_type(&self.input_schema) {
return ConstSimplifyResult::Simplified(ScalarValue::from(
data_type.to_string(),
));
}
}
}

let phys_expr =
match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
Expand Down

0 comments on commit 06a00f7

Please sign in to comment.