Skip to content

Commit

Permalink
fix failed test from #12050
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Aug 20, 2024
1 parent 83ce363 commit 3519e75
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 14 deletions.
18 changes: 14 additions & 4 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,14 @@ impl ExprSchemable for Expr {
Expr::ScalarFunction(ScalarFunction { func, args }) => {
Ok(func.is_nullable(args, input_schema))
}
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
Ok(func.is_nullable())
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
let nullables = args
.iter()
.map(|e| e.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(func.is_nullable(&nullables))
}
Expr::WindowFunction(WindowFunction { fun, .. }) => match fun {
Expr::WindowFunction(WindowFunction { fun, args, .. }) => match fun {
WindowFunctionDefinition::BuiltInWindowFunction(func) => {
if func.name() == "RANK"
|| func.name() == "NTILE"
Expand All @@ -352,7 +356,13 @@ impl ExprSchemable for Expr {
Ok(true)
}
}
WindowFunctionDefinition::AggregateUDF(func) => Ok(func.is_nullable()),
WindowFunctionDefinition::AggregateUDF(func) => {
let nullables = args
.iter()
.map(|e| e.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(func.is_nullable(&nullables))
}
WindowFunctionDefinition::WindowUDF(udwf) => Ok(udwf.nullable()),
},
Expr::ScalarVariable(_, _)
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ impl AggregateUDF {
self.inner.name()
}

pub fn is_nullable(&self) -> bool {
self.inner.is_nullable()
pub fn is_nullable(&self, nullables: &[bool]) -> bool {
self.inner.is_nullable(nullables)
}

/// Returns the aliases for this function.
Expand Down Expand Up @@ -355,8 +355,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
///
/// Nullable means that that the function could return `null` for any inputs.
/// For example, aggregate functions like `COUNT` always return a non null value
/// but others like `MIN` will return `NULL` if there is no non null input.
fn is_nullable(&self) -> bool {
/// but others like `MIN` will return `NULL` if there is nullable input.
fn is_nullable(&self, _nullables: &[bool]) -> bool {
true
}

Expand Down
3 changes: 2 additions & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ impl AggregateUDFImpl for Count {
Ok(DataType::Int64)
}

fn is_nullable(&self) -> bool {
// Count is always nullable regardless of the input nullability
fn is_nullable(&self, _nullables: &[bool]) -> bool {
false
}

Expand Down
8 changes: 7 additions & 1 deletion datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ use datafusion_common::{
};

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ColumnarValue, Expr, ExprSchemable, ScalarUDFImpl, Signature, Volatility,
};

/// Implements casting to arbitrary arrow types (rather than SQL types)
///
Expand Down Expand Up @@ -87,6 +89,10 @@ impl ScalarUDFImpl for ArrowCastFunc {
internal_err!("arrow_cast should return type from exprs")
}

fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
}

fn return_type_from_exprs(
&self,
args: &[Expr],
Expand Down
10 changes: 6 additions & 4 deletions datafusion/physical-expr-functions-aggregate/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ pub struct AggregateExprBuilder {
is_distinct: bool,
/// Whether the expression is reversed
is_reversed: bool,
is_nullable: bool,
}

impl AggregateExprBuilder {
pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
let is_nullable = fun.is_nullable();
Self {
fun,
args,
Expand All @@ -71,7 +69,6 @@ impl AggregateExprBuilder {
ignore_nulls: false,
is_distinct: false,
is_reversed: false,
is_nullable,
}
}

Expand All @@ -85,7 +82,6 @@ impl AggregateExprBuilder {
ignore_nulls,
is_distinct,
is_reversed,
is_nullable,
} = self;
if args.is_empty() {
return internal_err!("args should not be empty");
Expand All @@ -107,13 +103,19 @@ impl AggregateExprBuilder {
.map(|arg| arg.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

let input_nullables = args
.iter()
.map(|arg| arg.nullable(&schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let data_type = fun.return_type(&input_exprs_types)?;
let is_nullable = fun.is_nullable(&input_nullables);
let name = match alias {
// TODO: Ideally, we should build the name from physical expressions
None => create_function_physical_name(fun.name(), is_distinct, &[], None)?,
Expand Down

0 comments on commit 3519e75

Please sign in to comment.