Skip to content

Commit 140e667

Browse files
timsauceralamb
authored andcommitted
Rebase due to UDF changes upstream
1 parent 356ab99 commit 140e667

File tree

8 files changed

+149
-92
lines changed

8 files changed

+149
-92
lines changed

datafusion/ffi/src/tests/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ use datafusion::{
4242
};
4343
use sync_provider::create_sync_table_provider;
4444
use udf_udaf_udwf::{
45-
create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, create_ffi_stddev_func,
46-
create_ffi_sum_func, create_ffi_table_func,
45+
create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func,
46+
create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func,
4747
};
4848

4949
mod async_provider;

datafusion/ffi/src/tests/udf_udaf_udwf.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, udwf::FFI_WindowUDF};
18+
use crate::{
19+
udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction,
20+
udwf::FFI_WindowUDF,
21+
};
1922
use datafusion::{
2023
catalog::TableFunctionImpl,
2124
functions::math::{abs::AbsFunc, random::RandomFunc},
@@ -58,7 +61,13 @@ pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF {
5861
}
5962

6063
pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF {
61-
let udwf: Arc<WindowUDF> = Arc::new(Rank::new("rank_demo".to_string(), datafusion::functions_window::rank::RankType::Basic).into());
64+
let udwf: Arc<WindowUDF> = Arc::new(
65+
Rank::new(
66+
"rank_demo".to_string(),
67+
datafusion::functions_window::rank::RankType::Basic,
68+
)
69+
.into(),
70+
);
6271

6372
udwf.into()
6473
}

datafusion/ffi/src/udwf/mod.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ use arrow::{
2626
compute::SortOptions,
2727
datatypes::{DataType, SchemaRef},
2828
};
29+
use arrow_schema::{Field, FieldRef};
2930
use datafusion::{
3031
error::DataFusionError,
3132
logical_expr::{
32-
function::WindowUDFFieldArgs,
33-
type_coercion::functions::data_types_with_window_udf, PartitionEvaluator,
33+
function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf,
34+
PartitionEvaluator,
3435
},
3536
};
3637
use datafusion::{
@@ -45,6 +46,7 @@ mod partition_evaluator;
4546
mod partition_evaluator_args;
4647
mod range;
4748

49+
use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
4850
use crate::{
4951
arrow_wrappers::WrappedSchema,
5052
df_result, rresult, rresult_return,
@@ -130,16 +132,17 @@ unsafe extern "C" fn partition_evaluator_fn_wrapper(
130132

131133
unsafe extern "C" fn field_fn_wrapper(
132134
udwf: &FFI_WindowUDF,
133-
input_types: RVec<WrappedSchema>,
135+
input_fields: RVec<WrappedSchema>,
134136
display_name: RString,
135137
) -> RResult<WrappedSchema, RString> {
136138
let inner = udwf.inner();
137139

138-
let input_types = rresult_return!(rvec_wrapped_to_vec_datatype(&input_types));
140+
let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
139141

140-
let field = rresult_return!(
141-
inner.field(WindowUDFFieldArgs::new(&input_types, display_name.as_str()))
142-
);
142+
let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
143+
&input_fields,
144+
display_name.as_str()
145+
)));
143146

144147
let schema = Arc::new(Schema::new(vec![field]));
145148

@@ -152,9 +155,17 @@ unsafe extern "C" fn coerce_types_fn_wrapper(
152155
) -> RResult<RVec<WrappedSchema>, RString> {
153156
let inner = udwf.inner();
154157

155-
let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
158+
let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
159+
.into_iter()
160+
.map(|dt| Field::new("f", dt, false))
161+
.map(Arc::new)
162+
.collect::<Vec<_>>();
156163

157-
let return_types = rresult_return!(data_types_with_window_udf(&arg_types, inner));
164+
let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner));
165+
let return_types = return_fields
166+
.into_iter()
167+
.map(|f| f.data_type().to_owned())
168+
.collect::<Vec<_>>();
158169

159170
rresult!(vec_datatype_to_rvec_wrapped(&return_types))
160171
}
@@ -300,9 +311,9 @@ impl WindowUDFImpl for ForeignWindowUDF {
300311
})
301312
}
302313

303-
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<arrow::datatypes::Field> {
314+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
304315
unsafe {
305-
let input_types = vec_datatype_to_rvec_wrapped(field_args.input_types())?;
316+
let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
306317
let schema = df_result!((self.udf.field)(
307318
&self.udf,
308319
input_types,
@@ -314,7 +325,7 @@ impl WindowUDFImpl for ForeignWindowUDF {
314325
true => Err(DataFusionError::Execution(
315326
"Unable to retrieve field in WindowUDF via FFI".to_string(),
316327
)),
317-
false => Ok(schema.field(0).to_owned()),
328+
false => Ok(schema.field(0).to_owned().into()),
318329
}
319330
}
320331
}

datafusion/ffi/src/udwf/partition_evaluator.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use std::{ffi::c_void, ops::Range};
1919

20+
use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return};
2021
use abi_stable::{
2122
std_types::{RResult, RString, RVec},
2223
StableAbi,
@@ -29,10 +30,11 @@ use datafusion::{
2930
};
3031
use prost::Message;
3132

32-
use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return};
33-
3433
use super::range::FFI_Range;
3534

35+
/// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries.
36+
/// For an explanation of each field, see the corresponding function
37+
/// defined in [`PartitionEvaluator`].
3638
#[repr(C)]
3739
#[derive(Debug, StableAbi)]
3840
#[allow(non_camel_case_types)]
@@ -108,7 +110,8 @@ unsafe extern "C" fn evaluate_all_fn_wrapper(
108110
.collect::<Result<Vec<ArrayRef>>>();
109111
let values_arrays = rresult_return!(values_arrays);
110112

111-
let return_array = (inner.evaluate_all(&values_arrays, num_rows))
113+
let return_array = inner
114+
.evaluate_all(&values_arrays, num_rows)
112115
.and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from));
113116

114117
rresult!(return_array)
@@ -148,7 +151,8 @@ unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper(
148151
.map(Range::from)
149152
.collect::<Vec<_>>();
150153

151-
let return_array = (inner.evaluate_all_with_rank(num_rows, &ranks_in_partition))
154+
let return_array = inner
155+
.evaluate_all_with_rank(num_rows, &ranks_in_partition)
152156
.and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from));
153157

154158
rresult!(return_array)

datafusion/ffi/src/udwf/partition_evaluator_args.rs

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
use std::{collections::HashMap, sync::Arc};
1919

20+
use crate::arrow_wrappers::WrappedSchema;
2021
use abi_stable::{std_types::RVec, StableAbi};
2122
use arrow::{
2223
datatypes::{DataType, Field, Schema, SchemaRef},
2324
error::ArrowError,
2425
ffi::FFI_ArrowSchema,
2526
};
27+
use arrow_schema::FieldRef;
2628
use datafusion::{
2729
error::{DataFusionError, Result},
2830
logical_expr::function::PartitionEvaluatorArgs,
@@ -38,51 +40,52 @@ use datafusion_proto::{
3840
};
3941
use prost::Message;
4042

41-
use crate::arrow_wrappers::WrappedSchema;
42-
43+
/// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries.
44+
/// For an explanation of each field, see the corresponding function
45+
/// defined in [`PartitionEvaluatorArgs`].
4346
#[repr(C)]
4447
#[derive(Debug, StableAbi)]
4548
#[allow(non_camel_case_types)]
4649
pub struct FFI_PartitionEvaluatorArgs {
4750
input_exprs: RVec<RVec<u8>>,
48-
input_types: RVec<WrappedSchema>,
51+
input_fields: RVec<WrappedSchema>,
4952
is_reversed: bool,
5053
ignore_nulls: bool,
5154
schema: WrappedSchema,
5255
}
5356

54-
5557
impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
56-
5758
type Error = DataFusionError;
58-
fn try_from(
59-
args: PartitionEvaluatorArgs,
60-
) -> Result<Self, DataFusionError> {
59+
fn try_from(args: PartitionEvaluatorArgs) -> Result<Self, DataFusionError> {
6160
// This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema
6261
// around, and instead passes the data types directly we are unable to decode the
6362
// protobuf PhysicalExpr correctly. In evaluating the code the only place these
6463
// appear to be really used are the Column data types. So here we will find all
6564
// of the required columns and create a schema that has empty fields except for
6665
// the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just
6766
// pass along the schema, but that is a larger breaking change.
68-
let required_columns: HashMap<usize, (&str, &DataType)> = args.input_exprs().iter().zip(args.input_types())
69-
.filter_map(|(expr, data_type)| {
70-
expr.as_any().downcast_ref::<Column>().map(|column| (column.index(), (column.name(), data_type)))
67+
let required_columns: HashMap<usize, (&str, &DataType)> = args
68+
.input_exprs()
69+
.iter()
70+
.zip(args.input_fields())
71+
.filter_map(|(expr, field)| {
72+
expr.as_any()
73+
.downcast_ref::<Column>()
74+
.map(|column| (column.index(), (column.name(), field.data_type())))
7175
})
7276
.collect();
7377

7478
let max_column = required_columns.keys().max().unwrap_or(&0).to_owned();
75-
let fields: Vec<_> = (0..max_column).into_iter()
76-
.map(|idx| {
77-
match required_columns.get(&idx) {
78-
Some((name, data_type)) => {
79-
Field::new(*name, (*data_type).clone(), true)
80-
}
81-
None => {
82-
Field::new(format!("ffi_partition_evaluator_col_{idx}"), DataType::Null, true)
83-
}
84-
}
85-
}).collect();
79+
let fields: Vec<_> = (0..max_column)
80+
.map(|idx| match required_columns.get(&idx) {
81+
Some((name, data_type)) => Field::new(*name, (*data_type).clone(), true),
82+
None => Field::new(
83+
format!("ffi_partition_evaluator_col_{idx}"),
84+
DataType::Null,
85+
true,
86+
),
87+
})
88+
.collect();
8689
let schema = Arc::new(Schema::new(fields));
8790

8891
let codec = DefaultPhysicalExtensionCodec {};
@@ -91,8 +94,8 @@ impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
9194
.map(|expr_node| expr_node.encode_to_vec().into())
9295
.collect();
9396

94-
let input_types = args
95-
.input_types()
97+
let input_fields = args
98+
.input_fields()
9699
.iter()
97100
.map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema))
98101
.collect::<Result<Vec<_>, ArrowError>>()?
@@ -102,7 +105,7 @@ impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
102105

103106
Ok(Self {
104107
input_exprs,
105-
input_types,
108+
input_fields,
106109
schema,
107110
is_reversed: args.is_reversed(),
108111
ignore_nulls: args.ignore_nulls(),
@@ -116,7 +119,7 @@ impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
116119
/// PartitionEvaluatorArgs can then reference.
117120
pub struct ForeignPartitionEvaluatorArgs {
118121
input_exprs: Vec<Arc<dyn PhysicalExpr>>,
119-
input_types: Vec<DataType>,
122+
input_fields: Vec<FieldRef>,
120123
is_reversed: bool,
121124
ignore_nulls: bool,
122125
}
@@ -142,14 +145,14 @@ impl TryFrom<FFI_PartitionEvaluatorArgs> for ForeignPartitionEvaluatorArgs {
142145
})
143146
.collect::<Result<Vec<_>>>()?;
144147

145-
let input_types = input_exprs
148+
let input_fields = input_exprs
146149
.iter()
147-
.map(|expr| expr.data_type(&schema))
150+
.map(|expr| expr.return_field(&schema))
148151
.collect::<Result<Vec<_>>>()?;
149152

150153
Ok(Self {
151154
input_exprs,
152-
input_types,
155+
input_fields,
153156
is_reversed: value.is_reversed,
154157
ignore_nulls: value.ignore_nulls,
155158
})
@@ -160,7 +163,7 @@ impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a>
160163
fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self {
161164
PartitionEvaluatorArgs::new(
162165
&value.input_exprs,
163-
&value.input_types,
166+
&value.input_fields,
164167
value.is_reversed,
165168
value.ignore_nulls,
166169
)

datafusion/ffi/src/udwf/range.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ use std::ops::Range;
1919

2020
use abi_stable::StableAbi;
2121

22+
/// A stable struct for sharing [`Range`] across FFI boundaries.
23+
/// For an explanation of each field, see the corresponding function
24+
/// defined in [`Range`].
2225
#[repr(C)]
2326
#[derive(Debug, StableAbi)]
2427
#[allow(non_camel_case_types)]

datafusion/ffi/tests/ffi_integration.rs

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,11 @@
2020
#[cfg(feature = "integration-tests")]
2121
mod tests {
2222
use datafusion::error::{DataFusionError, Result};
23-
use datafusion::prelude::{col, SessionContext};
23+
use datafusion::prelude::SessionContext;
2424
use datafusion_ffi::catalog_provider::ForeignCatalogProvider;
2525
use datafusion_ffi::table_provider::ForeignTableProvider;
2626
use datafusion_ffi::tests::create_record_batch;
2727
use datafusion_ffi::tests::utils::get_module;
28-
use datafusion::logical_expr::{ScalarUDF, WindowUDF};
29-
use datafusion_ffi::udf::ForeignScalarUDF;
30-
use datafusion_ffi::udwf::ForeignWindowUDF;
31-
use std::path::Path;
3228
use std::sync::Arc;
3329

3430
/// It is important that this test is in the `tests` directory and not in the
@@ -99,41 +95,4 @@ mod tests {
9995

10096
Ok(())
10197
}
102-
103-
#[tokio::test]
104-
async fn test_rank_udwf() -> Result<()> {
105-
let module = get_module()?;
106-
107-
let ffi_rank_func =
108-
module
109-
.create_rank_udwf()
110-
.ok_or(DataFusionError::NotImplemented(
111-
"External table provider failed to implement create_scalar_udf"
112-
.to_string(),
113-
))?();
114-
let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?;
115-
116-
let udwf: WindowUDF = foreign_rank_func.into();
117-
118-
let ctx = SessionContext::default();
119-
let df = ctx.read_batch(create_record_batch(-5, 5))?;
120-
121-
let df = df
122-
.with_column("rank_a", udwf.call(vec![]))?;
123-
124-
let result = df.collect().await?;
125-
126-
let expected = record_batch!(
127-
("a", Int32, vec![-5, -4, -3, -2, -1]),
128-
("b", Float64, vec![-5., -4., -3., -2., -1.]),
129-
("abs_a", Int32, vec![5, 4, 3, 2, 1]),
130-
("abs_b", Float64, vec![5., 4., 3., 2., 1.])
131-
)?;
132-
133-
assert!(result.len() == 1);
134-
assert!(result[0] == expected);
135-
136-
Ok(())
137-
138-
}
13998
}

0 commit comments

Comments
 (0)