Skip to content
Merged
1 change: 1 addition & 0 deletions datafusion/catalog-listing/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
| Expr::Exists(_)
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::SetComparison(_)
| Expr::GroupingSet(_)
| Expr::Case(_) => Ok(TreeNodeRecursion::Continue),

Expand Down
82 changes: 82 additions & 0 deletions datafusion/core/tests/set_comparison.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::prelude::SessionContext;
use datafusion_common::{assert_batches_eq, Result};

fn build_table(values: &[i32]) -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
let array =
Arc::new(Int32Array::from(values.to_vec())) as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array]).map_err(Into::into)
}

#[tokio::test]
async fn set_comparison_any() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
// Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly.
ctx.register_batch("s", {
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
let array = Arc::new(Int32Array::from(vec![Some(5), None]))
as Arc<dyn arrow::array::Array>;
RecordBatch::try_new(schema, vec![array])?
})?;

let df = ctx
.sql("select v from t where v > any(select v from s)")
.await?;
let results = df.collect().await?;

assert_batches_eq!(
&["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",],
&results
);
Ok(())
}

#[tokio::test]
async fn set_comparison_all_empty() -> Result<()> {
let ctx = SessionContext::new();
Comment on lines +77 to +78

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about adding

async fn set_comparison_type_mismatch() -> Result<()> {
    // SELECT v FROM t WHERE v > ANY (SELECT s FROM strings)
    // INT > STRING should error with clear message

...

too?


ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
ctx.register_batch(
"e",
RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
"v",
DataType::Int32,
true,
)]))),
)?;

let df = ctx
.sql("select v from t where v < all(select v from e)")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tests for:

  • Multiple operators (=, !=, >=, <=)
  • NULL semantics (e.g., 5 != ALL (1, NULL)
    would improve test coverage

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @kosiew, I added above cases in 07e23bd

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@waynexia
Thanks for the ping.
I will check again after you fix the clippy errors.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kosiew, CI is all green now, please take another look, thank you!

.await?;
let results = df.collect().await?;

assert_batches_eq!(
&["+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+",],
&results
);
Ok(())
}
97 changes: 97 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ pub enum Expr {
Exists(Exists),
/// IN subquery
InSubquery(InSubquery),
/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
SetComparison(SetComparison),
/// Scalar subquery
ScalarSubquery(Subquery),
/// Represents a reference to all available fields in a specific schema,
Expand Down Expand Up @@ -1101,6 +1103,54 @@ impl Exists {
}
}

/// Whether the set comparison uses `ANY`/`SOME` or `ALL`
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub enum SetQuantifier {
/// `ANY` (or `SOME`)
Any,
/// `ALL`
All,
}

impl Display for SetQuantifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
SetQuantifier::Any => write!(f, "ANY"),
SetQuantifier::All => write!(f, "ALL"),
}
}
}

/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct SetComparison {
/// The expression to compare
pub expr: Box<Expr>,
/// Subquery that will produce a single column of data to compare against
pub subquery: Subquery,
/// Comparison operator (e.g. `=`, `>`, `<`)
pub op: Operator,
/// Quantifier (`ANY`/`ALL`)
pub quantifier: SetQuantifier,
}

impl SetComparison {
/// Create a new set comparison expression
pub fn new(
expr: Box<Expr>,
subquery: Subquery,
op: Operator,
quantifier: SetQuantifier,
) -> Self {
Self {
expr,
subquery,
op,
quantifier,
}
}
}

/// InList expression
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct InList {
Expand Down Expand Up @@ -1503,6 +1553,7 @@ impl Expr {
Expr::GroupingSet(..) => "GroupingSet",
Expr::InList { .. } => "InList",
Expr::InSubquery(..) => "InSubquery",
Expr::SetComparison(..) => "SetComparison",
Expr::IsNotNull(..) => "IsNotNull",
Expr::IsNull(..) => "IsNull",
Expr::Like { .. } => "Like",
Expand Down Expand Up @@ -2058,6 +2109,7 @@ impl Expr {
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::SetComparison(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
Expand Down Expand Up @@ -2651,6 +2703,16 @@ impl HashNode for Expr {
subquery.hash(state);
negated.hash(state);
}
Expr::SetComparison(SetComparison {
expr: _,
subquery,
op,
quantifier,
}) => {
subquery.hash(state);
op.hash(state);
quantifier.hash(state);
}
Expr::ScalarSubquery(subquery) => {
subquery.hash(state);
}
Expand Down Expand Up @@ -2841,6 +2903,12 @@ impl Display for SchemaDisplay<'_> {
write!(f, "NOT IN")
}
Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"),
Expr::SetComparison(SetComparison {
expr,
op,
quantifier,
..
}) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())),
Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)),
Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)),
Expr::IsNotTrue(expr) => {
Expand Down Expand Up @@ -3316,6 +3384,12 @@ impl Display for Expr {
subquery,
negated: false,
}) => write!(f, "{expr} IN ({subquery:?})"),
Expr::SetComparison(SetComparison {
expr,
subquery,
op,
quantifier,
}) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"),
Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"),
Expr::BinaryExpr(expr) => write!(f, "{expr}"),
Expr::ScalarFunction(fun) => {
Expand Down Expand Up @@ -3799,6 +3873,7 @@ mod test {
}

use super::*;
use crate::logical_plan::{EmptyRelation, LogicalPlan};

#[test]
fn test_display_wildcard() {
Expand Down Expand Up @@ -3889,6 +3964,28 @@ mod test {
)
}

#[test]
fn test_display_set_comparison() {
let subquery = Subquery {
subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
})),
outer_ref_columns: vec![],
spans: Spans::new(),
};

let expr = Expr::SetComparison(SetComparison::new(
Box::new(Expr::Column(Column::from_name("a"))),
subquery,
Operator::Gt,
SetQuantifier::Any,
));

assert_eq!(format!("{expr}"), "a > ANY (<subquery>)");
assert_eq!(format!("{}", expr.human_display()), "a > ANY (<subquery>)");
}

#[test]
fn test_schema_display_alias_with_relation() {
assert_eq!(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ impl ExprSchemable for Expr {
| Expr::IsNull(_)
| Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::SetComparison(_)
| Expr::Between { .. }
| Expr::InList { .. }
| Expr::IsNotNull(_)
Expand Down Expand Up @@ -380,6 +381,7 @@ impl ExprSchemable for Expr {
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Exists { .. } => Ok(false),
Expr::SetComparison(_) => Ok(true),
Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).is_nullable())
Expand Down Expand Up @@ -645,6 +647,7 @@ impl ExprSchemable for Expr {
| Expr::TryCast(_)
| Expr::InList(_)
| Expr::InSubquery(_)
| Expr::SetComparison(_)
| Expr::Wildcard { .. }
| Expr::GroupingSet(_)
| Expr::Placeholder(_)
Expand Down
20 changes: 18 additions & 2 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion_common::{

use crate::{
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
expr::{Exists, InSubquery},
expr::{Exists, InSubquery, SetComparison},
expr_rewriter::strip_outer_reference,
utils::{collect_subquery_cols, split_conjunction},
};
Expand Down Expand Up @@ -81,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::SetComparison(SetComparison { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
assert_valid_extension_nodes(&subquery.subquery, check)?;
}
Expand Down Expand Up @@ -133,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::SetComparison(SetComparison { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
check_subquery_expr(plan, &subquery.subquery, expr)?;
}
Expand Down Expand Up @@ -229,6 +231,20 @@ pub fn check_subquery_expr(
);
}
}
if let Expr::SetComparison(set_comparison) = expr
&& set_comparison.subquery.subquery.schema().fields().len() > 1
{
return plan_err!(
"Set comparison subquery should only return one column, but found {}: {}",
set_comparison.subquery.subquery.schema().fields().len(),
set_comparison
.subquery
.subquery
.schema()
.field_names()
.join(", ")
);
}
match outer_plan {
LogicalPlan::Projection(_)
| LogicalPlan::Filter(_)
Expand All @@ -237,7 +253,7 @@ pub fn check_subquery_expr(
| LogicalPlan::Aggregate(_)
| LogicalPlan::Join(_) => Ok(()),
_ => plan_err!(
"In/Exist subquery can only be used in \
"In/Exist/SetComparison subquery can only be used in \
Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
but was used in [{}]",
outer_plan.display()
Expand Down
19 changes: 18 additions & 1 deletion datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use crate::{
};
use datafusion_common::tree_node::TreeNodeRefContainer;

use crate::expr::{Exists, InSubquery};
use crate::expr::{Exists, InSubquery, SetComparison};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion,
TreeNodeRewriter, TreeNodeVisitor,
Expand Down Expand Up @@ -815,6 +815,7 @@ impl LogicalPlan {
expr.apply(|expr| match expr {
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::SetComparison(SetComparison { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
// use a synthetic plan so the collector sees a
// LogicalPlan::Subquery (even though it is
Expand Down Expand Up @@ -856,6 +857,22 @@ impl LogicalPlan {
})),
_ => internal_err!("Transformation should return Subquery"),
}),
Expr::SetComparison(SetComparison {
expr,
subquery,
op,
quantifier,
}) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s {
LogicalPlan::Subquery(subquery) => {
Ok(Expr::SetComparison(SetComparison {
expr,
subquery,
op,
quantifier,
}))
}
_ => internal_err!("Transformation should return Subquery"),
}),
Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))?
.map_data(|s| match s {
LogicalPlan::Subquery(subquery) => {
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,6 @@ pub trait ExprPlanner: Debug + Send + Sync {
)
}

/// Plans `ANY` expression, such as `expr = ANY(array_expr)`
///
/// Returns origin binary expression if not possible
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
Ok(PlannerResult::Original(expr))
}

/// Plans aggregate functions, such as `COUNT(<expr>)`
///
/// Returns original expression arguments if not possible
Expand Down
Loading