Skip to content

Commit 32ea724

Browse files
committed
Work in progress on integration testing of udwf
1 parent 67419ac commit 32ea724

File tree

5 files changed

+96
-32
lines changed

5 files changed

+96
-32
lines changed

datafusion/ffi/src/tests/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction};
3131

3232
use crate::udaf::FFI_AggregateUDF;
3333

34+
use crate::udwf::FFI_WindowUDF;
35+
3436
use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF};
3537
use arrow::array::RecordBatch;
3638
use async_provider::create_async_table_provider;
@@ -40,7 +42,7 @@ use datafusion::{
4042
};
4143
use sync_provider::create_sync_table_provider;
4244
use udf_udaf_udwf::{
43-
create_ffi_abs_func, create_ffi_random_func, create_ffi_stddev_func,
45+
create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, create_ffi_stddev_func,
4446
create_ffi_sum_func, create_ffi_table_func,
4547
};
4648

@@ -76,6 +78,8 @@ pub struct ForeignLibraryModule {
7678
/// Create a grouping UDAF using stddev
7779
pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF,
7880

81+
pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF,
82+
7983
pub version: extern "C" fn() -> u64,
8084
}
8185

@@ -125,6 +129,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef {
125129
create_table_function: create_ffi_table_func,
126130
create_sum_udaf: create_ffi_sum_func,
127131
create_stddev_udaf: create_ffi_stddev_func,
132+
create_rank_udwf: create_ffi_rank_func,
128133
version: super::version,
129134
}
130135
.leak_into_prefix()

datafusion/ffi/src/tests/udf_udaf_udwf.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction};
18+
use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, udwf::FFI_WindowUDF};
1919
use datafusion::{
2020
catalog::TableFunctionImpl,
2121
functions::math::{abs::AbsFunc, random::RandomFunc},
2222
functions_aggregate::{stddev::Stddev, sum::Sum},
2323
functions_table::generate_series::RangeFunc,
24-
logical_expr::{AggregateUDF, ScalarUDF},
24+
functions_window::rank::Rank,
25+
logical_expr::{AggregateUDF, ScalarUDF, WindowUDF},
2526
};
2627

2728
use std::sync::Arc;
@@ -55,3 +56,9 @@ pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF {
5556

5657
udaf.into()
5758
}
59+
60+
pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF {
61+
let udwf: Arc<WindowUDF> = Arc::new(Rank::new("rank_demo".to_string(), datafusion::functions_window::rank::RankType::Basic).into());
62+
63+
udwf.into()
64+
}

datafusion/ffi/src/udwf/mod.rs

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ pub struct FFI_WindowUDF {
8787
arg_types: RVec<WrappedSchema>,
8888
) -> RResult<RVec<WrappedSchema>, RString>,
8989

90-
pub schema: unsafe extern "C" fn(udwf: &Self) -> WrappedSchema,
91-
9290
pub sort_options: ROption<FFI_SortOptions>,
9391

9492
/// Used to create a clone on the provider of the udf. This should
@@ -108,19 +106,13 @@ unsafe impl Sync for FFI_WindowUDF {}
108106

109107
pub struct WindowUDFPrivateData {
110108
pub udf: Arc<WindowUDF>,
111-
pub schema: SchemaRef,
112109
}
113110

114111
impl FFI_WindowUDF {
115112
unsafe fn inner(&self) -> &Arc<WindowUDF> {
116113
let private_data = self.private_data as *const WindowUDFPrivateData;
117114
&(*private_data).udf
118115
}
119-
120-
unsafe fn inner_schema(&self) -> SchemaRef {
121-
let private_data = self.private_data as *const WindowUDFPrivateData;
122-
Arc::clone(&(*private_data).schema)
123-
}
124116
}
125117

126118
unsafe extern "C" fn partition_evaluator_fn_wrapper(
@@ -167,12 +159,6 @@ unsafe extern "C" fn coerce_types_fn_wrapper(
167159
rresult!(vec_datatype_to_rvec_wrapped(&return_types))
168160
}
169161

170-
unsafe extern "C" fn schema_fn_wrapper(udwf: &FFI_WindowUDF) -> WrappedSchema {
171-
let schema = udwf.inner_schema();
172-
173-
schema.into()
174-
}
175-
176162
unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
177163
let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
178164
drop(private_data);
@@ -187,7 +173,6 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
187173
// });
188174
let private_data = Box::new(WindowUDFPrivateData {
189175
udf: Arc::clone(udwf.inner()),
190-
schema: udwf.inner_schema(),
191176
});
192177

193178
FFI_WindowUDF {
@@ -196,7 +181,6 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
196181
volatility: udwf.volatility.clone(),
197182
partition_evaluator: partition_evaluator_fn_wrapper,
198183
sort_options: udwf.sort_options.clone(),
199-
schema: schema_fn_wrapper,
200184
coerce_types: coerce_types_fn_wrapper,
201185
field: field_fn_wrapper,
202186
clone: clone_fn_wrapper,
@@ -211,22 +195,21 @@ impl Clone for FFI_WindowUDF {
211195
}
212196
}
213197

214-
impl FFI_WindowUDF {
215-
pub fn new(udf: Arc<WindowUDF>, schema: SchemaRef) -> Self {
198+
impl From<Arc<WindowUDF>> for FFI_WindowUDF {
199+
fn from(udf: Arc<WindowUDF>) -> Self {
216200
let name = udf.name().into();
217201
let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
218202
let volatility = udf.signature().volatility.into();
219203
let sort_options = udf.sort_options().map(|v| (&v).into()).into();
220204

221-
let private_data = Box::new(WindowUDFPrivateData { udf, schema });
205+
let private_data = Box::new(WindowUDFPrivateData { udf });
222206

223207
Self {
224208
name,
225209
aliases,
226210
volatility,
227211
partition_evaluator: partition_evaluator_fn_wrapper,
228212
sort_options,
229-
schema: schema_fn_wrapper,
230213
coerce_types: coerce_types_fn_wrapper,
231214
field: field_fn_wrapper,
232215
clone: clone_fn_wrapper,
@@ -307,8 +290,7 @@ impl WindowUDFImpl for ForeignWindowUDF {
307290
args: datafusion::logical_expr::function::PartitionEvaluatorArgs,
308291
) -> Result<Box<dyn PartitionEvaluator>> {
309292
let evaluator = unsafe {
310-
let schema = (self.udf.schema)(&self.udf);
311-
let args = FFI_PartitionEvaluatorArgs::new(args, schema.into())?;
293+
let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
312294
(self.udf.partition_evaluator)(&self.udf, args)
313295
};
314296

datafusion/ffi/src/udwf/partition_evaluator_args.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
18+
use std::{collections::HashMap, sync::Arc};
1919

2020
use abi_stable::{std_types::RVec, StableAbi};
2121
use arrow::{
22-
datatypes::{DataType, SchemaRef},
22+
datatypes::{DataType, Field, Schema, SchemaRef},
2323
error::ArrowError,
2424
ffi::FFI_ArrowSchema,
2525
};
2626
use datafusion::{
2727
error::{DataFusionError, Result},
2828
logical_expr::function::PartitionEvaluatorArgs,
29-
physical_plan::PhysicalExpr,
29+
physical_plan::{expressions::Column, PhysicalExpr},
3030
prelude::SessionContext,
3131
};
3232
use datafusion_proto::{
@@ -51,11 +51,40 @@ pub struct FFI_PartitionEvaluatorArgs {
5151
schema: WrappedSchema,
5252
}
5353

54-
impl FFI_PartitionEvaluatorArgs {
55-
pub fn new(
54+
55+
impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
56+
57+
type Error = DataFusionError;
58+
fn try_from(
5659
args: PartitionEvaluatorArgs,
57-
schema: SchemaRef,
5860
) -> Result<Self, DataFusionError> {
61+
// This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema
62+
// around, and instead passes the data types directly, we are unable to decode the
63+
// protobuf PhysicalExpr correctly. In evaluating the code the only place these
64+
// appear to be really used are the Column data types. So here we will find all
65+
// of the required columns and create a schema that has empty fields except for
66+
// the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just
67+
// pass along the schema, but that is a larger breaking change.
68+
let required_columns: HashMap<usize, (&str, &DataType)> = args.input_exprs().iter().zip(args.input_types())
69+
.filter_map(|(expr, data_type)| {
70+
expr.as_any().downcast_ref::<Column>().map(|column| (column.index(), (column.name(), data_type)))
71+
})
72+
.collect();
73+
74+
let max_column = required_columns.keys().max().unwrap_or(&0).to_owned();
75+
let fields: Vec<_> = (0..max_column).into_iter()
76+
.map(|idx| {
77+
match required_columns.get(&idx) {
78+
Some((name, data_type)) => {
79+
Field::new(*name, (*data_type).clone(), true)
80+
}
81+
None => {
82+
Field::new(format!("ffi_partition_evaluator_col_{idx}"), DataType::Null, true)
83+
}
84+
}
85+
}).collect();
86+
let schema = Arc::new(Schema::new(fields));
87+
5988
let codec = DefaultPhysicalExtensionCodec {};
6089
let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)?
6190
.into_iter()

datafusion/ffi/tests/ffi_integration.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@
2020
#[cfg(feature = "integration-tests")]
2121
mod tests {
2222
use datafusion::error::{DataFusionError, Result};
23-
use datafusion::prelude::SessionContext;
23+
use datafusion::prelude::{col, SessionContext};
2424
use datafusion_ffi::catalog_provider::ForeignCatalogProvider;
2525
use datafusion_ffi::table_provider::ForeignTableProvider;
2626
use datafusion_ffi::tests::create_record_batch;
2727
use datafusion_ffi::tests::utils::get_module;
28+
use datafusion::logical_expr::{ScalarUDF, WindowUDF};
29+
use datafusion_ffi::udf::ForeignScalarUDF;
30+
use datafusion_ffi::udwf::ForeignWindowUDF;
31+
use std::path::Path;
2832
use std::sync::Arc;
2933

3034
/// It is important that this test is in the `tests` directory and not in the
@@ -95,4 +99,41 @@ mod tests {
9599

96100
Ok(())
97101
}
102+
103+
#[tokio::test]
104+
async fn test_rank_udwf() -> Result<()> {
105+
let module = get_module()?;
106+
107+
let ffi_rank_func =
108+
module
109+
.create_rank_udwf()
110+
.ok_or(DataFusionError::NotImplemented(
111+
"External table provider failed to implement create_scalar_udf"
112+
.to_string(),
113+
))?();
114+
let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?;
115+
116+
let udwf: WindowUDF = foreign_rank_func.into();
117+
118+
let ctx = SessionContext::default();
119+
let df = ctx.read_batch(create_record_batch(-5, 5))?;
120+
121+
let df = df
122+
.with_column("rank_a", udwf.call(vec![]))?;
123+
124+
let result = df.collect().await?;
125+
126+
let expected = record_batch!(
127+
("a", Int32, vec![-5, -4, -3, -2, -1]),
128+
("b", Float64, vec![-5., -4., -3., -2., -1.]),
129+
("abs_a", Int32, vec![5, 4, 3, 2, 1]),
130+
("abs_b", Float64, vec![5., 4., 3., 2., 1.])
131+
)?;
132+
133+
assert!(result.len() == 1);
134+
assert!(result[0] == expected);
135+
136+
Ok(())
137+
138+
}
98139
}

0 commit comments

Comments
 (0)