diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index a0a713765396..bf7e7934a0b8 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -834,6 +834,7 @@ dependencies = [ "indexmap", "itertools", "lazy_static", + "libc", "md-5", "paste", "petgraph", diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 78c9aca51561..670a25837a22 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1548,7 +1548,7 @@ impl SymmetricHashJoinStream { mod tests { use std::fs::File; - use arrow::array::{ArrayRef, IntervalDayTimeArray}; + use arrow::array::{ArrayRef, Float64Array, IntervalDayTimeArray}; use arrow::array::{Int32Array, TimestampMillisecondArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; @@ -1559,7 +1559,7 @@ mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, Column}; use datafusion_physical_expr::intervals::test_utils::{ - gen_conjunctive_numeric_expr, gen_conjunctive_temporal_expr, + gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, }; use datafusion_physical_expr::PhysicalExpr; @@ -1711,127 +1711,184 @@ mod tests { Ok(result) } - fn join_expr_tests_fixture( - expr_id: usize, - left_col: Arc, - right_col: Arc, - ) -> Arc { - match expr_id { - // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 0 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Plus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - 1, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 1 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - 1, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 - 2 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Minus, - Operator::Plus, - Operator::Minus, - Operator::Plus, - 1, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 - 3 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - 10, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 - 4 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - 10, - 5, - 30, - 3, - (Operator::Gt, Operator::Lt), - ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 - 5 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Minus, - 2, - 5, - 7, - 3, - (Operator::GtEq, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 6 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Plus, - Operator::Minus, - Operator::Plus, - Operator::Plus, - 28, - 11, - 21, - 39, - (Operator::Gt, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 7 => gen_conjunctive_numeric_expr( - left_col, - right_col, - Operator::Plus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - 28, - 11, - 21, - 39, - (Operator::GtEq, Operator::Lt), - ), - _ => unreachable!(), + // It creates join filters for different type of fields for testing. + macro_rules! join_expr_tests { + ($func_name:ident, $type:ty, $SCALAR:ident) => { + fn $func_name( + expr_id: usize, + left_col: Arc, + right_col: Arc, + ) -> Arc { + match expr_id { + // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 0 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 1 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 + 2 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 + 3 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 + 4 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(30 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + 5 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(2 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(7 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::GtEq, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + 6 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::Gt, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + 7 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::GtEq, Operator::Lt), + ), + _ => panic!("No case"), + } + } + }; + } + + join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); + join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); + + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::iter::Iterator; + + struct AscendingRandomFloatIterator { + prev: f64, + max: f64, + rng: StdRng, + } + + impl AscendingRandomFloatIterator { + fn new(min: f64, max: f64) -> Self { + let mut rng = StdRng::seed_from_u64(42); + let initial = rng.gen_range(min..max); + AscendingRandomFloatIterator { + prev: initial, + max, + rng, + } } } + + impl Iterator for AscendingRandomFloatIterator { + type Item = f64; + + fn next(&mut self) -> Option { + let value = self.rng.gen_range(self.prev..self.max); + self.prev = value; + Some(value) + } + } + fn join_expr_tests_fixture_temporal( expr_id: usize, left_col: Arc, @@ -1887,12 +1944,18 @@ mod tests { let cardinality = Arc::new(Int32Array::from_iter( initial_range.clone().map(|x| x % 4).collect::>(), )); - let cardinality_key = Arc::new(Int32Array::from_iter( + let cardinality_key_left = Arc::new(Int32Array::from_iter( initial_range .clone() .map(|x| x % key_cardinality.0) .collect::>(), )); + let cardinality_key_right = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.1) + .collect::>(), + )); let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ std::iter::repeat(None) .take(index as usize) @@ -1926,10 +1989,15 @@ mod tests { .collect::>(), )); + let float_asc = Arc::new(Float64Array::from_iter_values( + AscendingRandomFloatIterator::new(0., table_size as f64) + .take(table_size as usize), + )); + let left = RecordBatch::try_from_iter(vec![ ("la1", ordered.clone()), ("lb1", cardinality.clone()), - ("lc1", cardinality_key.clone()), + ("lc1", cardinality_key_left), ("lt1", time.clone()), ("la2", ordered.clone()), ("la1_des", ordered_des.clone()), @@ -1937,11 +2005,12 @@ mod tests { ("l_asc_null_last", ordered_asc_null_last.clone()), ("l_desc_null_first", ordered_desc_null_first.clone()), ("li1", interval_time.clone()), + ("l_float", float_asc.clone()), ])?; let right = RecordBatch::try_from_iter(vec![ ("ra1", ordered.clone()), ("rb1", cardinality), - ("rc1", cardinality_key), + ("rc1", cardinality_key_right), ("rt1", time), ("ra2", ordered), ("ra1_des", ordered_des), @@ -1949,6 +2018,7 @@ mod tests { ("r_asc_null_last", ordered_asc_null_last), ("r_desc_null_first", ordered_desc_null_first), ("ri1", interval_time), + ("r_float", float_asc), ])?; Ok((left, right)) } @@ -2140,7 +2210,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2201,7 +2271,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2312,7 +2382,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2537,7 +2607,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2600,7 +2670,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2664,7 +2734,7 @@ mod tests { Field::new("left", DataType::Int32, true), Field::new("right", DataType::Int32, true), ]); - let filter_expr = join_expr_tests_fixture( + let filter_expr = join_expr_tests_fixture_i32( case_expr, col("left", &intermediate_schema)?, col("right", &intermediate_schema)?, @@ -2802,17 +2872,19 @@ mod tests { Field::new("0", DataType::Int32, true), Field::new("1", DataType::Int32, true), ]); - let filter_expr = gen_conjunctive_numeric_expr( + let filter_expr = gen_conjunctive_numerical_expr( col("0", &intermediate_schema)?, col("1", &intermediate_schema)?, - Operator::Plus, - Operator::Minus, - Operator::Plus, - Operator::Plus, - 0, - 3, - 0, - 3, + ( + Operator::Plus, + Operator::Minus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(3)), + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(3)), (Operator::Gt, Operator::Lt), ); let column_indices = vec![ @@ -3033,4 +3105,78 @@ mod tests { Ok(()) } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn testing_ascending_float_pruning( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize, + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_float", left_schema)?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_float", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_batch, + right_batch, + Some(left_sorted), + Some(right_sorted), + 13, + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Float64, true), + Field::new("right", DataType::Float64, true), + ]); + let filter_expr = join_expr_tests_fixture_f64( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 10, // l_float + side: JoinSide::Left, + }, + ColumnIndex { + index: 10, // r_float + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b28ad534fbd2..31484bf7934f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -58,6 +58,7 @@ hashbrown = { version = "0.13", features = ["raw"] } indexmap = "1.9.2" itertools = { version = "0.10", features = ["use_std"] } lazy_static = { version = "^1.4.0" } +libc = "0.2.140" md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" petgraph = "0.6.2" diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 65c8850b39d6..3a682049a08f 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -552,10 +552,10 @@ pub fn check_support(expr: &Arc) -> bool { #[cfg(test)] mod tests { use super::*; - use crate::intervals::test_utils::gen_conjunctive_numeric_expr; use itertools::Itertools; use crate::expressions::{BinaryExpr, Column}; + use crate::intervals::test_utils::gen_conjunctive_numerical_expr; use datafusion_common::ScalarValue; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -564,31 +564,19 @@ mod tests { fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), - left_interval: (Option, Option), - right_interval: (Option, Option), - left_expected: (Option, Option), - right_expected: (Option, Option), + left_interval: Interval, + right_interval: Interval, + left_expected: Interval, + right_expected: Interval, result: PropagationResult, ) -> Result<()> { let col_stats = vec![ - ( - exprs_with_interval.0.clone(), - Interval::make(left_interval.0, left_interval.1, (false, false)), - ), - ( - exprs_with_interval.1.clone(), - Interval::make(right_interval.0, right_interval.1, (false, false)), - ), + (exprs_with_interval.0.clone(), left_interval), + (exprs_with_interval.1.clone(), right_interval), ]; let expected = vec![ - ( - exprs_with_interval.0.clone(), - Interval::make(left_expected.0, left_expected.1, (false, false)), - ), - ( - exprs_with_interval.1.clone(), - Interval::make(right_expected.0, right_expected.1, (false, false)), - ), + (exprs_with_interval.0.clone(), left_expected), + (exprs_with_interval.1.clone(), right_expected), ]; let mut graph = ExprIntervalGraph::try_new(expr)?; let expr_indexes = graph @@ -608,81 +596,71 @@ mod tests { let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; assert_eq!(exp_result, result); col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( - |((_, res), (_, expected))| { - // NOTE: These randomized tests only check the correnctness of - // endpoint values, not open/closedness. - assert_eq!(res.lower.value, expected.lower.value); - assert_eq!(res.upper.value, expected.upper.value); + |((_, calculated_interval_node), (_, expected))| { + // NOTE: These randomized tests only check for conservative containment, + // not openness/closedness of endpoints. + assert!(calculated_interval_node.lower.value <= expected.lower.value); + assert!(calculated_interval_node.upper.value >= expected.upper.value); }, ); Ok(()) } - fn generate_case( - expr: Arc, - left_col: Arc, - right_col: Arc, - seed: u64, - expr_left: i32, - expr_right: i32, - ) -> Result<()> { - let mut r = StdRng::seed_from_u64(seed); - - let (left_interval, right_interval, left_waited, right_waited) = if ASC { - let left = (Some(r.gen_range(0..1000)), None); - let right = (Some(r.gen_range(0..1000)), None); - ( - left, - right, - ( - Some(std::cmp::max(left.0.unwrap(), right.0.unwrap() + expr_left)), - None, - ), - ( - Some(std::cmp::max( - right.0.unwrap(), - left.0.unwrap() + expr_right, - )), - None, - ), - ) - } else { - let left = (None, Some(r.gen_range(0..1000))); - let right = (None, Some(r.gen_range(0..1000))); - ( - left, - right, - ( - None, - Some(std::cmp::min(left.1.unwrap(), right.1.unwrap() + expr_left)), - ), - ( - None, - Some(std::cmp::min( - right.1.unwrap(), - left.1.unwrap() + expr_right, - )), - ), - ) + macro_rules! generate_cases { + ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + fn $FUNC_NAME( + expr: Arc, + left_col: Arc, + right_col: Arc, + seed: u64, + expr_left: $TYPE, + expr_right: $TYPE, + ) -> Result<()> { + let mut r = StdRng::seed_from_u64(seed); + + let (left_given, right_given, left_expected, right_expected) = if ASC { + let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + ( + (Some(left), None), + (Some(right), None), + (Some(<$TYPE>::max(left, right + expr_left)), None), + (Some(<$TYPE>::max(right, left + expr_right)), None), + ) + } else { + let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + ( + (None, Some(left)), + (None, Some(right)), + (None, Some(<$TYPE>::min(left, right + expr_left))), + (None, Some(<$TYPE>::min(right, left + expr_right))), + ) + }; + + experiment( + expr, + (left_col, right_col), + Interval::make(left_given.0, left_given.1, (true, true)), + Interval::make(right_given.0, right_given.1, (true, true)), + Interval::make(left_expected.0, left_expected.1, (true, true)), + Interval::make(right_expected.0, right_expected.1, (true, true)), + PropagationResult::Success, + ) + } }; - experiment( - expr, - (left_col, right_col), - left_interval, - right_interval, - left_waited, - right_waited, - PropagationResult::Success, - )?; - Ok(()) } + generate_cases!(generate_case_i32, i32, Int32); + generate_cases!(generate_case_i64, i64, Int64); + generate_cases!(generate_case_f32, f32, Float32); + generate_cases!(generate_case_f64, f64, Float64); #[test] fn testing_not_possible() -> Result<()> { let left_col = Arc::new(Column::new("left_watermark", 0)); let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark > right_watermark + 5 + // left_watermark > right_watermark + 5 let left_and_1 = Arc::new(BinaryExpr::new( left_col.clone(), Operator::Plus, @@ -692,341 +670,293 @@ mod tests { experiment( expr, (left_col, right_col), - (Some(10), Some(20)), - (Some(100), None), - (Some(10), Some(20)), - (Some(100), None), + Interval::make(Some(10), Some(20), (true, true)), + Interval::make(Some(100), None, (true, true)), + Interval::make(Some(10), Some(20), (true, true)), + Interval::make(Some(100), None, (true, true)), PropagationResult::Infeasible, - )?; - Ok(()) - } - - #[rstest] - #[test] - fn case_1( - #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Plus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - 1, - 11, - 3, - 33, - (Operator::Gt, Operator::Lt), - ); - // l > r + 10 AND r > l - 30 - let l_gt_r = 10; - let r_gt_l = -30; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 10 AND l < r + 30 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) - } - #[rstest] - #[test] - fn case_2( - #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - 1, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ); - // l > r + 6 AND r > l - 7 - let l_gt_r = 6; - let r_gt_l = -7; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 6 AND l < r + 7 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) + ) } - #[rstest] - #[test] - fn case_3( - #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Plus, - Operator::Minus, - Operator::Plus, - 1, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ); - // l > r + 6 AND r > l - 13 - let l_gt_r = 6; - let r_gt_l = -13; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 6 AND l < r + 13 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) + macro_rules! integer_float_case_1 { + ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + #[rstest] + #[test] + fn $TEST_FUNC_NAME( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] + seed: u64, + #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, + #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + + // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 + let expr = gen_conjunctive_numerical_expr( + left_col.clone(), + right_col.clone(), + ( + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $TYPE)), + ScalarValue::$SCALAR(Some(11 as $TYPE)), + ScalarValue::$SCALAR(Some(3 as $TYPE)), + ScalarValue::$SCALAR(Some(33 as $TYPE)), + (greater_op, less_op), + ); + // l > r + 10 AND r > l - 30 + let l_gt_r = 10 as $TYPE; + let r_gt_l = -30 as $TYPE; + $GENERATE_CASE_FUNC_NAME::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 10 AND l < r + 30 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + $GENERATE_CASE_FUNC_NAME::( + expr, left_col, right_col, seed, l_lt_r, r_lt_l, + ) + } + }; } - #[rstest] - #[test] - fn case_4( - #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - 10, - 5, - 3, - 10, - (Operator::Gt, Operator::Lt), - ); - // l > r + 5 AND r > l - 13 - let l_gt_r = 5; - let r_gt_l = -13; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 5 AND l < r + 13 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - Ok(()) + integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32); + integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64); + integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64); + integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32); + + macro_rules! integer_float_case_2 { + ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + #[rstest] + #[test] + fn $TEST_FUNC_NAME( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] + seed: u64, + #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, + #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + + // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + let expr = gen_conjunctive_numerical_expr( + left_col.clone(), + right_col.clone(), + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $TYPE)), + ScalarValue::$SCALAR(Some(5 as $TYPE)), + ScalarValue::$SCALAR(Some(3 as $TYPE)), + ScalarValue::$SCALAR(Some(10 as $TYPE)), + (greater_op, less_op), + ); + // l > r + 6 AND r > l - 7 + let l_gt_r = 6 as $TYPE; + let r_gt_l = -7 as $TYPE; + $GENERATE_CASE_FUNC_NAME::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 7 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + $GENERATE_CASE_FUNC_NAME::( + expr, left_col, right_col, seed, l_lt_r, r_lt_l, + ) + } + }; } - #[rstest] - #[test] - fn case_5( - #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 - - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - 10, - 5, - 30, - 3, - (Operator::Gt, Operator::Lt), - ); - // l > r + 5 AND r > l - 27 - let l_gt_r = 5; - let r_gt_l = -27; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 5 AND l < r + 27 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) + integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32); + integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64); + integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64); + integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32); + + macro_rules! integer_float_case_3 { + ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + #[rstest] + #[test] + fn $TEST_FUNC_NAME( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] + seed: u64, + #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, + #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + + // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 + let expr = gen_conjunctive_numerical_expr( + left_col.clone(), + right_col.clone(), + ( + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $TYPE)), + ScalarValue::$SCALAR(Some(5 as $TYPE)), + ScalarValue::$SCALAR(Some(3 as $TYPE)), + ScalarValue::$SCALAR(Some(10 as $TYPE)), + (greater_op, less_op), + ); + // l > r + 6 AND r > l - 13 + let l_gt_r = 6 as $TYPE; + let r_gt_l = -13 as $TYPE; + $GENERATE_CASE_FUNC_NAME::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + $GENERATE_CASE_FUNC_NAME::( + expr, left_col, right_col, seed, l_lt_r, r_lt_l, + ) + } + }; } - #[rstest] - #[test] - fn case_6( - #[values(0, 1, 2, 123, 756, 63, 345, 6443, 12341, 142, 123, 8900)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark - 1 >= right_watermark + 5 AND left_watermark - 10 <= right_watermark + 3 - - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Minus, - Operator::Plus, - Operator::Minus, - Operator::Plus, - 1, - 5, - 10, - 3, - (Operator::GtEq, Operator::LtEq), - ); - // l >= r + 6 AND r >= l - 13 - let l_gt_r = 6; - let r_gt_l = -13; - - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - generate_case::(expr, left_col, right_col, seed, l_gt_r, r_gt_l)?; - - Ok(()) + integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32); + integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64); + integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64); + integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32); + + macro_rules! integer_float_case_4 { + ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + #[rstest] + #[test] + fn $TEST_FUNC_NAME( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] + seed: u64, + #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, + #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 + let expr = gen_conjunctive_numerical_expr( + left_col.clone(), + right_col.clone(), + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(10 as $TYPE)), + ScalarValue::$SCALAR(Some(5 as $TYPE)), + ScalarValue::$SCALAR(Some(3 as $TYPE)), + ScalarValue::$SCALAR(Some(10 as $TYPE)), + (greater_op, less_op), + ); + // l > r + 5 AND r > l - 13 + let l_gt_r = 5 as $TYPE; + let r_gt_l = -13 as $TYPE; + $GENERATE_CASE_FUNC_NAME::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + $GENERATE_CASE_FUNC_NAME::( + expr, left_col, right_col, seed, l_lt_r, r_lt_l, + ) + } + }; } - #[rstest] - #[test] - fn case_7( - #[values(0, 1, 2, 123, 77, 93, 104, 624, 115, 613, 8365, 9345)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark + 4 >= right_watermark + 5 AND left_watermark - 20 < right_watermark - 5 - - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Plus, - Operator::Plus, - Operator::Minus, - Operator::Minus, - 4, - 5, - 20, - 5, - (Operator::GtEq, Operator::Lt), - ); - // l >= r + 1 AND r > l - 15 - let l_gt_r = 1; - let r_gt_l = -15; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 5 AND l < r + 27 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) + integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32); + integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64); + integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64); + integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32); + + macro_rules! integer_float_case_5 { + ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => { + #[rstest] + #[test] + fn $TEST_FUNC_NAME( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)] + seed: u64, + #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator, + #[values(Operator::Lt, Operator::LtEq)] less_op: Operator, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 + let expr = gen_conjunctive_numerical_expr( + left_col.clone(), + right_col.clone(), + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(10 as $TYPE)), + ScalarValue::$SCALAR(Some(5 as $TYPE)), + ScalarValue::$SCALAR(Some(30 as $TYPE)), + ScalarValue::$SCALAR(Some(3 as $TYPE)), + (greater_op, less_op), + ); + // l > r + 5 AND r > l - 27 + let l_gt_r = 5 as $TYPE; + let r_gt_l = -27 as $TYPE; + $GENERATE_CASE_FUNC_NAME::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 27 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + $GENERATE_CASE_FUNC_NAME::( + expr, left_col, right_col, seed, l_lt_r, r_lt_l, + ) + } + }; } - #[rstest] - #[test] - fn case_8( - #[values(0, 1, 2, 24, 53, 412, 364, 345, 737, 1010, 52, 1554)] seed: u64, - ) -> Result<()> { - let left_col = Arc::new(Column::new("left_watermark", 0)); - let right_col = Arc::new(Column::new("right_watermark", 0)); - // left_watermark + 4 >= right_watermark + 5 AND left_watermark - 20 < right_watermark - 5 - - let expr = gen_conjunctive_numeric_expr( - left_col.clone(), - right_col.clone(), - Operator::Plus, - Operator::Plus, - Operator::Minus, - Operator::Minus, - 4, - 5, - 20, - 5, - (Operator::Gt, Operator::LtEq), - ); - // l >= r + 1 AND r > l - 15 - let l_gt_r = 1; - let r_gt_l = -15; - generate_case::( - expr.clone(), - left_col.clone(), - right_col.clone(), - seed, - l_gt_r, - r_gt_l, - )?; - // Descending tests - // r < l - 5 AND l < r + 27 - let r_lt_l = -l_gt_r; - let l_lt_r = -r_gt_l; - generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; - - Ok(()) - } + integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32); + integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64); + integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64); + integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32); #[test] fn test_gather_node_indices_dont_remove() -> Result<()> { @@ -1067,6 +997,7 @@ mod tests { assert_eq!(prev_node_count, final_node_count); Ok(()) } + #[test] fn test_gather_node_indices_remove() -> Result<()> { // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1. diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 9a4b0bfe8def..6c2d1b0f418f 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -28,6 +28,7 @@ use datafusion_expr::type_coercion::binary::coerce_types; use datafusion_expr::Operator; use crate::aggregate::min_max::{max, min}; +use crate::intervals::rounding::alter_fp_rounding_mode; /// This type represents a single endpoint of an [`Interval`]. An endpoint can /// be open or closed, denoting whether the interval includes or excludes the @@ -75,38 +76,54 @@ impl IntervalBound { /// The result is unbounded if either is; otherwise, their values are /// added. The result is closed if both original bounds are closed, or open /// otherwise. - pub fn add>(&self, other: T) -> Result { + pub fn add>( + &self, + other: T, + ) -> Result { let rhs = other.borrow(); if self.is_unbounded() || rhs.is_unbounded() { - IntervalBound::make_unbounded(coerce_types( + return IntervalBound::make_unbounded(coerce_types( &self.get_datatype(), &Operator::Plus, &rhs.get_datatype(), - )?) - } else { - self.value - .add(&rhs.value) - .map(|v| IntervalBound::new(v, self.open || rhs.open)) + )?); } + match self.get_datatype() { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { + lhs.add(rhs) + }) + } + _ => self.value.add(&rhs.value), + } + .map(|v| IntervalBound::new(v, self.open || rhs.open)) } /// This function subtracts the given `IntervalBound` from `self`. /// The result is unbounded if either is; otherwise, their values are /// subtracted. The result is closed if both original bounds are closed, /// or open otherwise. - pub fn sub>(&self, other: T) -> Result { + pub fn sub>( + &self, + other: T, + ) -> Result { let rhs = other.borrow(); if self.is_unbounded() || rhs.is_unbounded() { - IntervalBound::make_unbounded(coerce_types( + return IntervalBound::make_unbounded(coerce_types( &self.get_datatype(), &Operator::Minus, &rhs.get_datatype(), - )?) - } else { - self.value - .sub(&rhs.value) - .map(|v| IntervalBound::new(v, self.open || rhs.open)) + )?); + } + match self.get_datatype() { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { + lhs.sub(rhs) + }) + } + _ => self.value.sub(&rhs.value), } + .map(|v| IntervalBound::new(v, self.open || rhs.open)) } /// This function chooses one of the given `IntervalBound`s according to @@ -404,8 +421,8 @@ impl Interval { pub fn add>(&self, other: T) -> Result { let rhs = other.borrow(); Ok(Interval::new( - self.lower.add(&rhs.lower)?, - self.upper.add(&rhs.upper)?, + self.lower.add::(&rhs.lower)?, + self.upper.add::(&rhs.upper)?, )) } @@ -416,8 +433,8 @@ impl Interval { pub fn sub>(&self, other: T) -> Result { let rhs = other.borrow(); Ok(Interval::new( - self.lower.sub(&rhs.upper)?, - self.upper.sub(&rhs.lower)?, + self.lower.sub::(&rhs.upper)?, + self.upper.sub::(&rhs.lower)?, )) } @@ -463,6 +480,8 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { | &DataType::UInt32 | &DataType::UInt16 | &DataType::UInt8 + | &DataType::Float64 + | &DataType::Float32 ) } @@ -1041,7 +1060,7 @@ mod tests { // This function tests if valid constructions produce standardized objects // ([false, false], [false, true], [true, true]) for boolean intervals. #[test] - fn non_standard_interval_constructs() -> Result<()> { + fn non_standard_interval_constructs() { let cases = vec![ ( IntervalBound::new(Boolean(None), true), @@ -1078,6 +1097,80 @@ mod tests { for case in cases { assert_eq!(Interval::new(case.0, case.1), case.2) } - Ok(()) + } + + macro_rules! capture_mode_change { + ($TYPE:ty) => { + paste::item! { + capture_mode_change_helper!([], + [], + $TYPE); + } + }; + } + + macro_rules! capture_mode_change_helper { + ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { + fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { + Interval::make(Some(lower as $TYPE), Some(upper as $TYPE), (true, true)) + } + + fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { + assert!(expect_low || expect_high); + let interval1 = $CREATE_FN_NAME(input.0, input.0); + let interval2 = $CREATE_FN_NAME(input.1, input.1); + let result = interval1.add(&interval2).unwrap(); + let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); + assert!( + (!expect_low || result.lower.value < without_fe.lower.value) + && (!expect_high || result.upper.value > without_fe.upper.value) + ); + } + }; + } + + capture_mode_change!(f32); + capture_mode_change!(f64); + + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + #[test] + fn test_add_intervals_lower_affected_f32() { + // Lower is affected + let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 + let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 + capture_mode_change_f32((lower, upper), true, false); + + // Upper is affected + let lower = f32::from_bits(1072693248); //111111111100000000000000000000 + let upper = f32::from_bits(715827883); //101010101010101010101010101011 + capture_mode_change_f32((lower, upper), false, true); + + // Lower is affected + let lower = 1.0; // 0x3FF0000000000000 + let upper = 0.3; // 0x3FD3333333333333 + capture_mode_change_f64((lower, upper), true, false); + + // Upper is affected + let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF + let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F + capture_mode_change_f64((lower, upper), false, true); + } + + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + #[test] + fn test_next_impl_add_intervals_f64() { + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f64((lower, upper), true, true); + + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f32((lower, upper), true, true); } } diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index 9883ba15b2e7..a9255752fea4 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -//! Interval calculations -//! +//! Interval arithmetic and constraint propagation library pub mod cp_solver; pub mod interval_aritmetic; +pub mod rounding; pub mod test_utils; pub use cp_solver::{check_support, ExprIntervalGraph}; diff --git a/datafusion/physical-expr/src/intervals/rounding.rs b/datafusion/physical-expr/src/intervals/rounding.rs new file mode 100644 index 000000000000..06c4f9e8a957 --- /dev/null +++ b/datafusion/physical-expr/src/intervals/rounding.rs @@ -0,0 +1,401 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Floating point rounding mode utility library +//! TODO: Remove this custom implementation and the "libc" dependency when +//! floating-point rounding mode manipulation functions become available +//! in Rust. + +use std::ops::{Add, BitAnd, Sub}; + +use datafusion_common::Result; +use datafusion_common::ScalarValue; + +// Define constants for ARM +#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] +const FE_UPWARD: i32 = 0x00400000; +#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] +const FE_DOWNWARD: i32 = 0x00800000; + +// Define constants for x86_64 +#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))] +const FE_UPWARD: i32 = 0x0800; +#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))] +const FE_DOWNWARD: i32 = 0x0400; + +#[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") +))] +extern crate libc; + +#[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") +))] +extern "C" { + fn fesetround(round: i32); + fn fegetround() -> i32; +} + +/// A trait to manipulate floating-point types with bitwise operations. +/// Provides functions to convert a floating-point value to/from its bitwise +/// representation as well as utility methods to handle special values. +pub trait FloatBits { + /// The integer type used for bitwise operations. + type Item: Copy + + PartialEq + + BitAnd + + Add + + Sub; + + /// The smallest positive floating-point value representable by this type. + const TINY_BITS: Self::Item; + + /// The smallest (in magnitude) negative floating-point value representable by this type. + const NEG_TINY_BITS: Self::Item; + + /// A mask to clear the sign bit of the floating-point value's bitwise representation. + const CLEAR_SIGN_MASK: Self::Item; + + /// The integer value 1, used in bitwise operations. + const ONE: Self::Item; + + /// The integer value 0, used in bitwise operations. + const ZERO: Self::Item; + + /// Converts the floating-point value to its bitwise representation. + fn to_bits(self) -> Self::Item; + + /// Converts the bitwise representation to the corresponding floating-point value. + fn from_bits(bits: Self::Item) -> Self; + + /// Returns true if the floating-point value is NaN (not a number). + fn float_is_nan(self) -> bool; + + /// Returns the positive infinity value for this floating-point type. + fn infinity() -> Self; + + /// Returns the negative infinity value for this floating-point type. + fn neg_infinity() -> Self; +} + +impl FloatBits for f32 { + type Item = u32; + const TINY_BITS: u32 = 0x1; // Smallest positive f32. + const NEG_TINY_BITS: u32 = 0x8000_0001; // Smallest (in magnitude) negative f32. + const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff; + const ONE: Self::Item = 1; + const ZERO: Self::Item = 0; + + fn to_bits(self) -> Self::Item { + self.to_bits() + } + + fn from_bits(bits: Self::Item) -> Self { + f32::from_bits(bits) + } + + fn float_is_nan(self) -> bool { + self.is_nan() + } + + fn infinity() -> Self { + f32::INFINITY + } + + fn neg_infinity() -> Self { + f32::NEG_INFINITY + } +} + +impl FloatBits for f64 { + type Item = u64; + const TINY_BITS: u64 = 0x1; // Smallest positive f64. + const NEG_TINY_BITS: u64 = 0x8000_0000_0000_0001; // Smallest (in magnitude) negative f64. + const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff; + const ONE: Self::Item = 1; + const ZERO: Self::Item = 0; + + fn to_bits(self) -> Self::Item { + self.to_bits() + } + + fn from_bits(bits: Self::Item) -> Self { + f64::from_bits(bits) + } + + fn float_is_nan(self) -> bool { + self.is_nan() + } + + fn infinity() -> Self { + f64::INFINITY + } + + fn neg_infinity() -> Self { + f64::NEG_INFINITY + } +} + +/// Returns the next representable floating-point value greater than the input value. +/// +/// This function takes a floating-point value that implements the FloatBits trait, +/// calculates the next representable value greater than the input, and returns it. +/// +/// If the input value is NaN or positive infinity, the function returns the input value. +/// +/// # Examples +/// +/// ``` +/// use datafusion_physical_expr::intervals::rounding::next_up; +/// +/// let f: f32 = 1.0; +/// let next_f = next_up(f); +/// assert_eq!(next_f, 1.0000001); +/// ``` +#[allow(dead_code)] +pub fn next_up(float: F) -> F { + let bits = float.to_bits(); + if float.float_is_nan() || bits == F::infinity().to_bits() { + return float; + } + + let abs = bits & F::CLEAR_SIGN_MASK; + let next_bits = if abs == F::ZERO { + F::TINY_BITS + } else if bits == abs { + bits + F::ONE + } else { + bits - F::ONE + }; + F::from_bits(next_bits) +} + +/// Returns the next representable floating-point value smaller than the input value. +/// +/// This function takes a floating-point value that implements the FloatBits trait, +/// calculates the next representable value smaller than the input, and returns it. +/// +/// If the input value is NaN or negative infinity, the function returns the input value. +/// +/// # Examples +/// +/// ``` +/// use datafusion_physical_expr::intervals::rounding::next_down; +/// +/// let f: f32 = 1.0; +/// let next_f = next_down(f); +/// assert_eq!(next_f, 0.99999994); +/// ``` +#[allow(dead_code)] +pub fn next_down(float: F) -> F { + let bits = float.to_bits(); + if float.float_is_nan() || bits == F::neg_infinity().to_bits() { + return float; + } + let abs = bits & F::CLEAR_SIGN_MASK; + let next_bits = if abs == F::ZERO { + F::NEG_TINY_BITS + } else if bits == abs { + bits - F::ONE + } else { + bits + F::ONE + }; + F::from_bits(next_bits) +} + +#[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" +))] +fn alter_fp_rounding_mode_conservative( + lhs: &ScalarValue, + rhs: &ScalarValue, + operation: F, +) -> Result +where + F: FnOnce(&ScalarValue, &ScalarValue) -> Result, +{ + let mut result = operation(lhs, rhs)?; + match &mut result { + ScalarValue::Float64(Some(value)) => { + if UPPER { + *value = next_up(*value) + } else { + *value = next_down(*value) + } + } + ScalarValue::Float32(Some(value)) => { + if UPPER { + *value = next_up(*value) + } else { + *value = next_down(*value) + } + } + _ => {} + }; + Ok(result) +} + +pub fn alter_fp_rounding_mode( + lhs: &ScalarValue, + rhs: &ScalarValue, + operation: F, +) -> Result +where + F: FnOnce(&ScalarValue, &ScalarValue) -> Result, +{ + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + unsafe { + let current = fegetround(); + fesetround(if UPPER { FE_UPWARD } else { FE_DOWNWARD }); + let result = operation(lhs, rhs); + fesetround(current); + result + } + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + alter_fp_rounding_mode_conservative::(lhs, rhs, operation) +} + +#[cfg(test)] +mod tests { + use super::{next_down, next_up}; + + #[test] + fn test_next_down() { + let x = 1.0f64; + // Clamp value into range [0, 1). + let clamped = x.clamp(0.0, next_down(1.0f64)); + assert!(clamped < 1.0); + assert_eq!(next_up(clamped), 1.0); + } + + #[test] + fn test_next_up_small_positive() { + let value: f64 = 1.0; + let result = next_up(value); + assert_eq!(result, 1.0000000000000002); + } + + #[test] + fn test_next_up_small_negative() { + let value: f64 = -1.0; + let result = next_up(value); + assert_eq!(result, -0.9999999999999999); + } + + #[test] + fn test_next_up_pos_infinity() { + let value: f64 = f64::INFINITY; + let result = next_up(value); + assert_eq!(result, f64::INFINITY); + } + + #[test] + fn test_next_up_nan() { + let value: f64 = f64::NAN; + let result = next_up(value); + assert!(result.is_nan()); + } + + #[test] + fn test_next_down_small_positive() { + let value: f64 = 1.0; + let result = next_down(value); + assert_eq!(result, 0.9999999999999999); + } + + #[test] + fn test_next_down_small_negative() { + let value: f64 = -1.0; + let result = next_down(value); + assert_eq!(result, -1.0000000000000002); + } + + #[test] + fn test_next_down_neg_infinity() { + let value: f64 = f64::NEG_INFINITY; + let result = next_down(value); + assert_eq!(result, f64::NEG_INFINITY); + } + + #[test] + fn test_next_down_nan() { + let value: f64 = f64::NAN; + let result = next_down(value); + assert!(result.is_nan()); + } + + #[test] + fn test_next_up_small_positive_f32() { + let value: f32 = 1.0; + let result = next_up(value); + assert_eq!(result, 1.0000001); + } + + #[test] + fn test_next_up_small_negative_f32() { + let value: f32 = -1.0; + let result = next_up(value); + assert_eq!(result, -0.99999994); + } + + #[test] + fn test_next_up_pos_infinity_f32() { + let value: f32 = f32::INFINITY; + let result = next_up(value); + assert_eq!(result, f32::INFINITY); + } + + #[test] + fn test_next_up_nan_f32() { + let value: f32 = f32::NAN; + let result = next_up(value); + assert!(result.is_nan()); + } + #[test] + fn test_next_down_small_positive_f32() { + let value: f32 = 1.0; + let result = next_down(value); + assert_eq!(result, 0.99999994); + } + #[test] + fn test_next_down_small_negative_f32() { + let value: f32 = -1.0; + let result = next_down(value); + assert_eq!(result, -1.0000001); + } + #[test] + fn test_next_down_neg_infinity_f32() { + let value: f32 = f32::NEG_INFINITY; + let result = next_down(value); + assert_eq!(result, f32::NEG_INFINITY); + } + #[test] + fn test_next_down_nan_f32() { + let value: f32 = f32::NAN; + let result = next_down(value); + assert!(result.is_nan()); + } +} diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index f233b246a73c..6bbf74dc7d7f 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -29,40 +29,31 @@ use datafusion_expr::Operator; /// This test function generates a conjunctive statement with two numeric /// terms with the following form: /// left_col (op_1) a >/>= right_col (op_2) b AND left_col (op_3) c , right_col: Arc, - op_1: Operator, - op_2: Operator, - op_3: Operator, - op_4: Operator, - a: i32, - b: i32, - c: i32, - d: i32, + op: (Operator, Operator, Operator, Operator), + a: ScalarValue, + b: ScalarValue, + c: ScalarValue, + d: ScalarValue, bounds: (Operator, Operator), ) -> Arc { + let (op_1, op_2, op_3, op_4) = op; let left_and_1 = Arc::new(BinaryExpr::new( left_col.clone(), op_1, - Arc::new(Literal::new(ScalarValue::Int32(Some(a)))), + Arc::new(Literal::new(a)), )); let left_and_2 = Arc::new(BinaryExpr::new( right_col.clone(), op_2, - Arc::new(Literal::new(ScalarValue::Int32(Some(b)))), - )); - - let right_and_1 = Arc::new(BinaryExpr::new( - left_col, - op_3, - Arc::new(Literal::new(ScalarValue::Int32(Some(c)))), - )); - let right_and_2 = Arc::new(BinaryExpr::new( - right_col, - op_4, - Arc::new(Literal::new(ScalarValue::Int32(Some(d)))), + Arc::new(Literal::new(b)), )); + let right_and_1 = + Arc::new(BinaryExpr::new(left_col, op_3, Arc::new(Literal::new(c)))); + let right_and_2 = + Arc::new(BinaryExpr::new(right_col, op_4, Arc::new(Literal::new(d)))); let (greater_op, less_op) = bounds; let left_expr = Arc::new(BinaryExpr::new(left_and_1, greater_op, left_and_2));