Skip to content

Commit

Permalink
Remove Arc wrapping from create_udf's return_type (#12489)
Browse files Browse the repository at this point in the history
The argument types are moved into `create_udf` so moving also
`return_type` would increase API consistency.

Internally, the `create_udf` unwrapped or cloned (so moves) the passed in
return type Arc, so there was no non-API benefit from using a shared
pointer.
  • Loading branch information
findepi authored Sep 17, 2024
1 parent be42f3d commit 0a64f34
Show file tree
Hide file tree
Showing 10 changed files with 15 additions and 16 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async fn main() -> Result<()> {
// expects two f64
vec![DataType::Float64, DataType::Float64],
// returns f64
Arc::new(DataType::Float64),
DataType::Float64,
Volatility::Immutable,
pow,
);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2772,7 +2772,7 @@ mod tests {
ctx.register_udf(create_udf(
"my_fn",
vec![DataType::Float64],
Arc::new(DataType::Float64),
DataType::Float64,
Volatility::Immutable,
my_fn,
));
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
// Make a UDF that adds its two values together, with the specified volatility
fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
let input_types = vec![DataType::Int32, DataType::Int32];
let return_type = Arc::new(DataType::Int32);
let return_type = DataType::Int32;

let fun = Arc::new(|args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async fn scalar_udf() -> Result<()> {
ctx.register_udf(create_udf(
"my_add",
vec![DataType::Int32, DataType::Int32],
Arc::new(DataType::Int32),
DataType::Int32,
Volatility::Immutable,
myfunc,
));
Expand Down Expand Up @@ -237,7 +237,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> {
ctx.register_udf(create_udf(
"buggy_func",
vec![DataType::Int32],
Arc::new(DataType::Int32),
DataType::Int32,
Volatility::Immutable,
buggy_udf,
));
Expand Down Expand Up @@ -321,7 +321,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
ctx.register_udf(create_udf(
"abs",
vec![DataType::Int32],
Arc::new(DataType::Int32),
DataType::Int32,
Volatility::Immutable,
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
));
Expand Down Expand Up @@ -414,7 +414,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
ctx.register_udf(create_udf(
"MY_FUNC",
vec![DataType::Int32],
Arc::new(DataType::Int32),
DataType::Int32,
Volatility::Immutable,
myfunc,
));
Expand Down Expand Up @@ -459,7 +459,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
let udf = create_udf(
"dummy",
vec![DataType::Int32],
Arc::new(DataType::Int32),
DataType::Int32,
Volatility::Immutable,
myfunc,
)
Expand Down Expand Up @@ -1149,7 +1149,7 @@ fn create_udf_context() -> SessionContext {
ctx.register_udf(create_udf(
"custom_sqrt",
vec![DataType::Float64],
Arc::new(DataType::Float64),
DataType::Float64,
Volatility::Immutable,
Arc::new(custom_sqrt),
));
Expand Down
3 changes: 1 addition & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,10 @@ pub fn unnest(expr: Expr) -> Expr {
pub fn create_udf(
name: &str,
input_types: Vec<DataType>,
return_type: Arc<DataType>,
return_type: DataType,
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUDF {
let return_type = Arc::unwrap_or_clone(return_type);
ScalarUDF::from(SimpleScalarUDF::new(
name,
input_types,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl Serializeable for Expr {
Ok(Arc::new(create_udf(
name,
vec![],
Arc::new(arrow::datatypes::DataType::Null),
arrow::datatypes::DataType::Null,
Volatility::Immutable,
Arc::new(|_| unimplemented!()),
)))
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2172,7 +2172,7 @@ fn roundtrip_scalar_udf() {
let udf = create_udf(
"dummy",
vec![DataType::Utf8],
Arc::new(DataType::Utf8),
DataType::Utf8,
Volatility::Immutable,
scalar_fn,
);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
let udf = create_udf(
"dummy",
vec![DataType::Int64],
Arc::new(DataType::Int64),
DataType::Int64,
Volatility::Immutable,
scalar_fn.clone(),
);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ fn context_with_udf() -> SessionContext {
let udf = create_udf(
"dummy",
vec![DataType::Utf8],
Arc::new(DataType::Utf8),
DataType::Utf8,
Volatility::Immutable,
scalar_fn,
);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ fn create_example_udf() -> ScalarUDF {
// Expects two f64 values:
vec![DataType::Float64, DataType::Float64],
// Returns an f64 value:
Arc::new(DataType::Float64),
DataType::Float64,
Volatility::Immutable,
adder,
)
Expand Down

0 comments on commit 0a64f34

Please sign in to comment.