Make builtin window function output datatype to be derived from schema (apache#9686)

* Make builtin window function output datatype to be derived from schema
comphead committed Mar 20, 2024
1 parent 89efc4a commit 1d0171a
Showing 3 changed files with 72 additions and 36 deletions.
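
The gist of the change, as a minimal standalone sketch using only arrow-schema (the function name `output_type_from_schema` and the `ROW_NUMBER()` field are illustrative, not DataFusion APIs): rather than recomputing a built-in window function's return type from its argument types, the physical layer now reads the data type of the schema field that the planner already created for the window expression.

    use arrow_schema::{ArrowError, DataType, Field, Schema};

    // Look up the output type of a window expression by the name of the field
    // the planner added for it, instead of recomputing it from argument types.
    fn output_type_from_schema(schema: &Schema, expr_name: &str) -> Result<DataType, ArrowError> {
        Ok(schema.field_with_name(expr_name)?.data_type().clone())
    }

    fn main() -> Result<(), ArrowError> {
        // Hypothetical plan schema: one input column plus the field the planner
        // appended for a ROW_NUMBER() window expression.
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("ROW_NUMBER()", DataType::UInt64, false),
        ]);
        assert_eq!(output_type_from_schema(&schema, "ROW_NUMBER()")?, DataType::UInt64);
        Ok(())
    }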
22 changes: 10 additions & 12 deletions datafusion/core/src/physical_planner.rs
@@ -742,13 +742,13 @@ impl DefaultPhysicalPlanner {
                     );
                 }
 
-                let logical_input_schema = input.schema();
+                let logical_schema = logical_plan.schema();
                 let window_expr = window_expr
                     .iter()
                     .map(|e| {
                         create_window_expr(
                             e,
-                            logical_input_schema,
+                            logical_schema,
                             session_state.execution_props(),
                         )
                     })
@@ -1578,11 +1578,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool {
 pub fn create_window_expr_with_name(
     e: &Expr,
     name: impl Into<String>,
-    logical_input_schema: &DFSchema,
+    logical_schema: &DFSchema,
     execution_props: &ExecutionProps,
 ) -> Result<Arc<dyn WindowExpr>> {
     let name = name.into();
-    let physical_input_schema: &Schema = &logical_input_schema.into();
+    let physical_schema: &Schema = &logical_schema.into();
     match e {
         Expr::WindowFunction(WindowFunction {
             fun,
@@ -1594,17 +1594,15 @@ pub fn create_window_expr_with_name(
         }) => {
             let args = args
                 .iter()
-                .map(|e| create_physical_expr(e, logical_input_schema, execution_props))
+                .map(|e| create_physical_expr(e, logical_schema, execution_props))
                 .collect::<Result<Vec<_>>>()?;
             let partition_by = partition_by
                 .iter()
-                .map(|e| create_physical_expr(e, logical_input_schema, execution_props))
+                .map(|e| create_physical_expr(e, logical_schema, execution_props))
                 .collect::<Result<Vec<_>>>()?;
             let order_by = order_by
                 .iter()
-                .map(|e| {
-                    create_physical_sort_expr(e, logical_input_schema, execution_props)
-                })
+                .map(|e| create_physical_sort_expr(e, logical_schema, execution_props))
                 .collect::<Result<Vec<_>>>()?;
 
             if !is_window_frame_bound_valid(window_frame) {
@@ -1625,7 +1623,7 @@ pub fn create_window_expr_with_name(
                 &partition_by,
                 &order_by,
                 window_frame,
-                physical_input_schema,
+                physical_schema,
                 ignore_nulls,
             )
         }
@@ -1636,15 +1634,15 @@ pub fn create_window_expr_with_name(
 /// Create a window expression from a logical expression or an alias
 pub fn create_window_expr(
     e: &Expr,
-    logical_input_schema: &DFSchema,
+    logical_schema: &DFSchema,
     execution_props: &ExecutionProps,
 ) -> Result<Arc<dyn WindowExpr>> {
     // unpack aliased logical expressions, e.g. "sum(col) over () as total"
     let (name, e) = match e {
         Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()),
         _ => (e.display_name()?, e),
     };
-    create_window_expr_with_name(e, name, logical_input_schema, execution_props)
+    create_window_expr_with_name(e, name, logical_schema, execution_props)
 }
 
 type AggregateExprWithOptionalArgs = (
39 changes: 36 additions & 3 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
+use arrow_schema::{Field, Schema};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::windows::{
@@ -39,6 +40,7 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::expressions::{cast, col, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::Itertools;
 use test_utils::add_empty_batches;
 
 use hashbrown::HashMap;
@@ -273,14 +275,17 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
                 window_frame.is_causal()
             };
 
+            let extended_schema =
+                schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
+
             let window_expr = create_window_expr(
                 &window_fn,
                 fn_name.to_string(),
                 &args,
                 &partitionby_exprs,
                 &orderby_exprs,
                 Arc::new(window_frame),
-                schema.as_ref(),
+                &extended_schema,
                 false,
             )?;
             let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
@@ -678,6 +683,8 @@ async fn run_window_test(
         exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
     }
 
+    let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
+
     let usual_window_exec = Arc::new(WindowAggExec::try_new(
         vec![create_window_expr(
             &window_fn,
@@ -686,7 +693,7 @@ async fn run_window_test(
            &partitionby_exprs,
            &orderby_exprs,
            Arc::new(window_frame.clone()),
-           schema.as_ref(),
+           &extended_schema,
            false,
        )?],
        exec1,
@@ -704,7 +711,7 @@ async fn run_window_test(
            &partitionby_exprs,
            &orderby_exprs,
            Arc::new(window_frame.clone()),
-           schema.as_ref(),
+           &extended_schema,
            false,
        )?],
        exec2,
@@ -747,6 +754,32 @@ async fn run_window_test(
     Ok(())
 }
 
+// The planner has the schema fully updated before calling `create_window_expr`.
+// Replicate the same for this test.
+fn schema_add_window_fields(
+    args: &[Arc<dyn PhysicalExpr>],
+    schema: &Arc<Schema>,
+    window_fn: &WindowFunctionDefinition,
+    fn_name: &str,
+) -> Result<Arc<Schema>> {
+    let data_types = args
+        .iter()
+        .map(|e| e.clone().as_ref().data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+    let window_expr_return_type = window_fn.return_type(&data_types)?;
+    let mut window_fields = schema
+        .fields()
+        .iter()
+        .map(|f| f.as_ref().clone())
+        .collect_vec();
+    window_fields.extend_from_slice(&[Field::new(
+        fn_name,
+        window_expr_return_type,
+        true,
+    )]);
+    Ok(Arc::new(Schema::new(window_fields)))
+}
+
 /// Return randomly sized record batches with:
 /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
 /// one random int32 column x
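
The `schema_add_window_fields` helper above mirrors what the planner does before `create_window_expr` is called: it appends a nullable field named after the window function, typed with the function's return type, so the window expression's output type can later be read back from the schema. A stripped-down sketch of the same idea, using only arrow-schema (here `window_return_type` stands in for the `window_fn.return_type(..)` call the real helper performs):

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    // Append a nullable field for the window expression to the input schema,
    // so its output type can later be looked up by name.
    fn add_window_field(
        input: &Arc<Schema>,
        fn_name: &str,
        window_return_type: DataType,
    ) -> Arc<Schema> {
        let mut fields: Vec<Field> =
            input.fields().iter().map(|f| f.as_ref().clone()).collect();
        fields.push(Field::new(fn_name, window_return_type, true));
        Arc::new(Schema::new(fields))
    }

Passing this extended schema (rather than the plain input schema) to `create_window_expr` keeps the fuzz test consistent with the planner, where the output field already exists by the time the physical window expression is built.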
47 changes: 26 additions & 21 deletions datafusion/physical-plan/src/windows/mod.rs
@@ -174,20 +174,15 @@ fn create_built_in_window_expr(
     name: String,
     ignore_nulls: bool,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
-    // need to get the types into an owned vec for some reason
-    let input_types: Vec<_> = args
-        .iter()
-        .map(|arg| arg.data_type(input_schema))
-        .collect::<Result<_>>()?;
+    // derive the output datatype from incoming schema
+    let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type();
 
-    // figure out the output type
-    let data_type = &fun.return_type(&input_types)?;
     Ok(match fun {
-        BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)),
-        BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)),
-        BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)),
-        BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)),
-        BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)),
+        BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)),
+        BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)),
+        BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)),
+        BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)),
+        BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)),
         BuiltInWindowFunction::Ntile => {
             let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| {
                 DataFusionError::Execution(
@@ -201,13 +196,13 @@ fn create_built_in_window_expr(
 
             if n.is_unsigned() {
                 let n: u64 = n.try_into()?;
-                Arc::new(Ntile::new(name, n, data_type))
+                Arc::new(Ntile::new(name, n, out_data_type))
             } else {
                 let n: i64 = n.try_into()?;
                 if n <= 0 {
                     return exec_err!("NTILE requires a positive integer");
                 }
-                Arc::new(Ntile::new(name, n as u64, data_type))
+                Arc::new(Ntile::new(name, n as u64, out_data_type))
             }
         }
         BuiltInWindowFunction::Lag => {
@@ -216,10 +211,10 @@ fn create_built_in_window_expr(
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
             let default_value =
-                get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
+                get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
             Arc::new(lag(
                 name,
-                data_type.clone(),
+                out_data_type.clone(),
                 arg,
                 shift_offset,
                 default_value,
@@ -232,10 +227,10 @@ fn create_built_in_window_expr(
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
             let default_value =
-                get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
+                get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
             Arc::new(lead(
                 name,
-                data_type.clone(),
+                out_data_type.clone(),
                 arg,
                 shift_offset,
                 default_value,
@@ -252,18 +247,28 @@ fn create_built_in_window_expr(
             Arc::new(NthValue::nth(
                 name,
                 arg,
-                data_type.clone(),
+                out_data_type.clone(),
                 n,
                 ignore_nulls,
             )?)
         }
         BuiltInWindowFunction::FirstValue => {
             let arg = args[0].clone();
-            Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls))
+            Arc::new(NthValue::first(
+                name,
+                arg,
+                out_data_type.clone(),
+                ignore_nulls,
+            ))
         }
         BuiltInWindowFunction::LastValue => {
            let arg = args[0].clone();
-            Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls))
+            Arc::new(NthValue::last(
+                name,
+                arg,
+                out_data_type.clone(),
+                ignore_nulls,
+            ))
         }
     })
 }
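
A practical consequence of deriving `out_data_type` via `field_with_name`: the schema passed to `create_built_in_window_expr` must contain a field whose name matches the window expression's name exactly, otherwise the lookup fails. A small illustration of that failure mode, using only arrow-schema (field names are made up for the example):

    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("row_number_col", DataType::UInt64, false),
        ]);
        // The lookup is by exact field name; any other spelling is not found.
        assert!(schema.field_with_name("row_number_col").is_ok());
        assert!(schema.field_with_name("ROW_NUMBER()").is_err());
    }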