Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ name = "math_query_sql"
harness = false
name = "filter_query_sql"

[[bench]]
harness = false
name = "struct_query_sql"

[[bench]]
harness = false
name = "window_query_sql"
Expand Down
79 changes: 79 additions & 0 deletions datafusion/core/benches/struct_query_sql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::{
array::{Float32Array, Float64Array},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion::prelude::SessionContext;
use datafusion::{datasource::MemTable, error::Result};
use futures::executor::block_on;
use std::sync::Arc;
use tokio::runtime::Runtime;

async fn query(ctx: &SessionContext, sql: &str) {
let rt = Runtime::new().unwrap();

// execute the query
let df = rt.block_on(ctx.sql(sql)).unwrap();
criterion::black_box(rt.block_on(df.collect()).unwrap());
}

/// Build a `SessionContext` with an in-memory table `t` containing
/// `array_len / batch_size` record batches of `batch_size` rows each.
/// Every row in batch `i` holds the value `i` in both float columns.
fn create_context(array_len: usize, batch_size: usize) -> Result<SessionContext> {
    // Two non-nullable float columns.
    let schema = Arc::new(Schema::new(vec![
        Field::new("f32", DataType::Float32, false),
        Field::new("f64", DataType::Float64, false),
    ]));

    // Build one batch per chunk; each batch is filled with its own index.
    let num_batches = array_len / batch_size;
    let mut batches = Vec::with_capacity(num_batches);
    for i in 0..num_batches {
        let f32_col = Float32Array::from(vec![i as f32; batch_size]);
        let f64_col = Float64Array::from(vec![i as f64; batch_size]);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(f32_col), Arc::new(f64_col)],
        )
        .unwrap();
        batches.push(batch);
    }

    let ctx = SessionContext::new();

    // Register the batches as an in-memory table. In the Spark API this
    // corresponds to createDataFrame(...).
    let provider = MemTable::try_new(schema, vec![batches])?;
    ctx.register_table("t", Arc::new(provider))?;

    Ok(ctx)
}

/// Benchmark building struct values via SQL over an in-memory table.
fn criterion_benchmark(c: &mut Criterion) {
    // 2^19 total rows split into 2^12-row batches => 128 record batches.
    let array_len = 524_288;
    let batch_size = 4096;
    let sql = "select struct(f32, f64) from t";

    c.bench_function("struct", |b| {
        let ctx = create_context(array_len, batch_size).unwrap();
        b.iter(|| block_on(query(&ctx, sql)));
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
100 changes: 26 additions & 74 deletions datafusion/functions/src/core/named_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,79 +17,15 @@

use arrow::array::StructArray;
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs};
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;

/// Put values in a struct array.
fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// Do not accept 0 arguments.
if args.is_empty() {
return exec_err!(
"named_struct requires at least one pair of arguments, got 0 instead"
);
}

if args.len() % 2 != 0 {
return exec_err!(
"named_struct requires an even number of arguments, got {} instead",
args.len()
);
}

let (names, values): (Vec<_>, Vec<_>) = args
.chunks_exact(2)
.enumerate()
.map(|(i, chunk)| {
let name_column = &chunk[0];
let name = match name_column {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => {
name_scalar
}
// TODO: Implement Display for ColumnarValue
_ => {
return exec_err!(
"named_struct even arguments must be string literals at position {}",
i * 2
)
}
};

Ok((name, chunk[1].clone()))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();

{
// Check to enforce the uniqueness of struct field name
let mut unique_field_names = HashSet::new();
for name in names.iter() {
if unique_field_names.contains(name) {
return exec_err!(
"named_struct requires unique field names. Field {name} is used more than once."
);
}
unique_field_names.insert(name);
}
}

let fields: Fields = names
.into_iter()
.zip(&values)
.map(|(name, value)| Arc::new(Field::new(name, value.data_type().clone(), true)))
.collect::<Vec<_>>()
.into();

let arrays = ColumnarValue::values_to_arrays(&values)?;

let struct_array = StructArray::new(fields, arrays, None);
Ok(ColumnarValue::Array(Arc::new(struct_array)))
}

#[user_doc(
doc_section(label = "Struct Functions"),
description = "Returns an Arrow struct using the specified name and input expressions pairs.",
Expand Down Expand Up @@ -203,12 +139,28 @@ impl ScalarUDFImpl for NamedStructFunc {
))))
}

fn invoke_batch(
Copy link

@lichuang lichuang Jan 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the invoke_batch function is deleted, the default invoke_batch implementation will call invoke when args is not empty, but invoke is left to be implemented by the specific type — is that OK?

Copy link
Contributor Author

@pepijnve pepijnve Jan 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what the intention is for the deprecated functions from an API user point of view. I had interpreted ScalarUDF(Impl) as an extension point rather than as public API that library users would call directly. If that assumption is correct, then it doesn't matter if we remove the invoke_batch implementation since ScalarFunctionExpr never calls this method. The only usages of invoke_batch are in benchmark and test code.
Perhaps @alamb, who logged #13515, can provide some insight here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we kept around deprecated functions is to give downstream users of DataFusion time to adjust their code -- especially on upgrade, having a deprecated function with guidance on what to change has been helpful

You can read more about this strategy here:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm coming from the Java world so I'll use the terminology from there; not sure what the equivalent is in Rust lingo. It's not explicitly stated which parts of the library are Service Provider Interface vs Application Programming Interface. My assumption was that ScalarUDFImpl is SPI and invoke_batch was kept around to not break all existing implementations and ScalarFunctionExpr is the API side of things which doesn't expose invoke_batch. Is that a correct interpretation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My assumption was that ScalarUDFImpl is SPI

Yes, that is accurate in my understanding

and invoke_batch was kept around to not break all existing implementations

Yes that is also my understanding

and ScalarFunctionExpr is the API side of things which doesn't expose invoke_batch. Is that a correct interpretation?

I would say ScalarFunctionExpr is an implementation detail of how functions are invoked in the ExecutionPlan (aka the physical execution)

The split between logical/physical plans is explained a bit in the API docs / intro videos in case you are interested:
https://docs.rs/datafusion/latest/datafusion/index.html#query-planning-and-execution-overview

&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
named_struct_expr(args)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let DataType::Struct(fields) = args.return_type else {
return internal_err!("incorrect named_struct return type");
};

assert_eq!(
fields.len(),
args.args.len() / 2,
"return type field count != argument count / 2"
);

let values: Vec<ColumnarValue> = args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😍

.args
.chunks_exact(2)
.map(|chunk| chunk[1].clone())
.collect();
let arrays = ColumnarValue::values_to_arrays(&values)?;
Ok(ColumnarValue::Array(Arc::new(StructArray::new(
fields.clone(),
arrays,
None,
))))
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down
74 changes: 30 additions & 44 deletions datafusion/functions/src/core/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayRef, StructArray};
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, Documentation};
use arrow::array::StructArray;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;

fn array_struct(args: &[ArrayRef]) -> Result<ArrayRef> {
// do not accept 0 arguments.
if args.is_empty() {
return exec_err!("struct requires at least one argument");
}

let fields = args
.iter()
.enumerate()
.map(|(i, arg)| {
let field_name = format!("c{i}");
Ok(Arc::new(Field::new(
field_name.as_str(),
arg.data_type().clone(),
true,
)))
})
.collect::<Result<Vec<_>>>()?
.into();

let arrays = args.to_vec();

Ok(Arc::new(StructArray::new(fields, arrays, None)))
}

/// put values in a struct array.
fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // Materialize every argument as an array, then assemble the struct
    // column from those arrays.
    let arrays = ColumnarValue::values_to_arrays(args)?;
    let struct_array = array_struct(arrays.as_slice())?;
    Ok(ColumnarValue::Array(struct_array))
}

#[user_doc(
doc_section(label = "Struct Functions"),
description = "Returns an Arrow struct using the specified input expressions optionally named.
Expand Down Expand Up @@ -133,20 +102,37 @@ impl ScalarUDFImpl for StructFunc {
}

/// Compute the struct return type from the argument types.
///
/// Rejects a zero-argument call at planning time; otherwise produces a
/// struct with auto-generated field names `c0`, `c1`, ..., each field
/// typed after its argument and marked nullable.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
    if arg_types.is_empty() {
        return exec_err!("struct requires at least one argument, got 0 instead");
    }

    let fields = arg_types
        .iter()
        .enumerate()
        .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true))
        .collect::<Vec<Field>>()
        .into();

    Ok(DataType::Struct(fields))
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
struct_expr(args)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    // Reuse the field list computed by `return_type`; it holds exactly
    // one field (`c0`, `c1`, ...) per argument.
    let DataType::Struct(fields) = args.return_type else {
        return internal_err!("incorrect struct return type");
    };

    // Invariant from planning: one struct field per argument.
    assert_eq!(
        fields.len(),
        args.args.len(),
        "return type field count != argument count"
    );

    // Materialize the arguments as arrays and wrap them in a struct column.
    let arrays = ColumnarValue::values_to_arrays(&args.args)?;
    let struct_array = StructArray::new(fields.clone(), arrays, None);
    Ok(ColumnarValue::Array(Arc::new(struct_array)))
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down