diff --git a/src/query/expression/src/aggregate/group_hash.rs b/src/query/expression/src/aggregate/group_hash.rs index 660fc11ace3b2..15f36900af8fd 100644 --- a/src/query/expression/src/aggregate/group_hash.rs +++ b/src/query/expression/src/aggregate/group_hash.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use databend_common_arrow::arrow::buffer::Buffer; +use databend_common_arrow::arrow::types::Index; use databend_common_base::base::OrderedFloat; +use databend_common_exception::Result; use ethnum::i256; -use crate::types::decimal::DecimalType; -use crate::types::geometry::GeometryType; use crate::types::AnyType; use crate::types::ArgType; use crate::types::BinaryType; @@ -24,17 +25,29 @@ use crate::types::BitmapType; use crate::types::BooleanType; use crate::types::DataType; use crate::types::DateType; +use crate::types::DecimalColumn; use crate::types::DecimalDataType; +use crate::types::DecimalScalar; +use crate::types::DecimalType; +use crate::types::GeographyType; +use crate::types::GeometryType; +use crate::types::NumberColumn; use crate::types::NumberDataType; +use crate::types::NumberScalar; use crate::types::NumberType; use crate::types::StringType; use crate::types::TimestampType; use crate::types::ValueType; use crate::types::VariantType; +use crate::visitor::ValueVisitor; +use crate::with_decimal_type; use crate::with_number_mapped_type; +use crate::with_number_type; use crate::Column; use crate::InputColumns; +use crate::Scalar; use crate::ScalarRef; +use crate::Value; const NULL_HASH_VAL: u64 = 0xd1cefa08eb382d69; @@ -47,15 +60,6 @@ pub fn group_hash_columns(cols: InputColumns, values: &mut [u64]) { } } -pub fn group_hash_columns_slice(cols: &[Column], values: &mut [u64]) { - debug_assert!(!cols.is_empty()); - let mut iter = cols.iter(); - combine_group_hash_column::(iter.next().unwrap(), values); - for col in iter { - combine_group_hash_column::(col, values); - } -} - pub fn combine_group_hash_column(c: &Column, values: &mut [u64]) { match c.data_type() { DataType::Null => {} @@ -82,6 +86,9 @@ pub fn combine_group_hash_column(c: &Column, values: &mut DataType::Bitmap => combine_group_hash_string_column::(c, values), DataType::Variant => combine_group_hash_string_column::(c, values), DataType::Geometry => combine_group_hash_string_column::(c, values), + DataType::Geography => { + combine_group_hash_string_column::(c, values) + } DataType::Nullable(_) => { let col = c.as_nullable().unwrap(); if IS_FIRST { @@ -149,6 +156,199 @@ fn combine_group_hash_string_column( } } +pub fn group_hash_value_spread( + indices: &[I], + value: Value, + first: bool, + target: &mut [u64], +) -> Result<()> { + if first { + let mut v = IndexHashVisitor::::new(indices, target); + v.visit_value(value) + } else { + let mut v = IndexHashVisitor::::new(indices, target); + v.visit_value(value) + } +} + +struct IndexHashVisitor<'a, 'b, const IS_FIRST: bool, I> +where I: Index +{ + indices: &'a [I], + target: &'b mut [u64], +} + +impl<'a, 'b, const IS_FIRST: bool, I> IndexHashVisitor<'a, 'b, IS_FIRST, I> +where I: Index +{ + fn new(indices: &'a [I], target: &'b mut [u64]) -> Self { + Self { indices, target } + } +} + +impl<'a, 'b, const IS_FIRST: bool, I> ValueVisitor for IndexHashVisitor<'a, 'b, IS_FIRST, I> +where I: Index +{ + fn visit_scalar(&mut self, scalar: Scalar) -> Result<()> { + let hash = match scalar { + Scalar::EmptyArray | Scalar::EmptyMap => return Ok(()), + Scalar::Null => NULL_HASH_VAL, + Scalar::Number(v) => with_number_type!(|NUM_TYPE| match v { + NumberScalar::NUM_TYPE(v) => v.agg_hash(), + }), + Scalar::Decimal(v) => match v { + DecimalScalar::Decimal128(v, _) => v.agg_hash(), + DecimalScalar::Decimal256(v, _) => v.agg_hash(), + }, + Scalar::Timestamp(v) => v.agg_hash(), + Scalar::Date(v) => v.agg_hash(), + Scalar::Boolean(v) => v.agg_hash(), + Scalar::Binary(v) => v.agg_hash(), + Scalar::String(v) => v.as_bytes().agg_hash(), + Scalar::Variant(v) => v.agg_hash(), + Scalar::Geometry(v) => v.agg_hash(), + Scalar::Geography(v) => v.0.agg_hash(), + v => v.as_ref().agg_hash(), + }; + self.visit_indices(|_| hash) + } + + fn visit_null(&mut self, _len: usize) -> Result<()> { + Ok(()) + } + + fn visit_empty_array(&mut self, _len: usize) -> Result<()> { + Ok(()) + } + + fn visit_empty_map(&mut self, _len: usize) -> Result<()> { + Ok(()) + } + + fn visit_any_number(&mut self, column: crate::types::NumberColumn) -> Result<()> { + with_number_type!(|NUM_TYPE| match column { + NumberColumn::NUM_TYPE(buffer) => { + let buffer = buffer.as_ref(); + self.visit_indices(|i| buffer[i.to_usize()].agg_hash()) + } + }) + } + + fn visit_timestamp(&mut self, buffer: Buffer) -> Result<()> { + self.visit_number(buffer) + } + + fn visit_date(&mut self, buffer: Buffer) -> Result<()> { + self.visit_number(buffer) + } + + fn visit_any_decimal(&mut self, column: DecimalColumn) -> Result<()> { + with_decimal_type!(|DECIMAL_TYPE| match column { + DecimalColumn::DECIMAL_TYPE(buffer, _) => { + let buffer = buffer.as_ref(); + self.visit_indices(|i| buffer[i.to_usize()].agg_hash()) + } + }) + } + + fn visit_binary(&mut self, column: crate::types::BinaryColumn) -> Result<()> { + self.visit_indices(|i| column.index(i.to_usize()).unwrap().agg_hash()) + } + + fn visit_variant(&mut self, column: crate::types::BinaryColumn) -> Result<()> { + self.visit_binary(column) + } + + fn visit_bitmap(&mut self, column: crate::types::BinaryColumn) -> Result<()> { + self.visit_binary(column) + } + + fn visit_string(&mut self, column: crate::types::StringColumn) -> Result<()> { + self.visit_indices(|i| column.index(i.to_usize()).unwrap().as_bytes().agg_hash()) + } + + fn visit_boolean( + &mut self, + bitmap: databend_common_arrow::arrow::bitmap::Bitmap, + ) -> Result<()> { + self.visit_indices(|i| bitmap.get(i.to_usize()).unwrap().agg_hash()) + } + + fn visit_geometry(&mut self, column: crate::types::BinaryColumn) -> Result<()> { + self.visit_binary(column) + } + + fn visit_geography(&mut self, column: crate::types::GeographyColumn) -> Result<()> { + self.visit_binary(column.0) + } + + fn visit_nullable(&mut self, column: Box>) -> Result<()> { + let indices = self + .indices + .iter() + .cloned() + .filter(|&i| { + let i = i.to_usize(); + let ok = column.validity.get(i).unwrap(); + if !ok { + let val = &mut self.target[i]; + if IS_FIRST { + *val = NULL_HASH_VAL; + } else { + *val = merge_hash(*val, NULL_HASH_VAL); + } + } + ok + }) + .collect::>(); + if IS_FIRST { + let mut v = IndexHashVisitor::::new(&indices, self.target); + v.visit_column(column.column) + } else { + let mut v = IndexHashVisitor::::new(&indices, self.target); + v.visit_column(column.column) + } + } + + fn visit_typed_column(&mut self, column: T::Column) -> Result<()> { + self.visit_indices(|i| { + let x = T::upcast_scalar(T::to_owned_scalar( + T::index_column(&column, i.to_usize()).unwrap(), + )); + x.as_ref().agg_hash() + }) + } +} + +impl<'a, 'b, const IS_FIRST: bool, I> IndexHashVisitor<'a, 'b, IS_FIRST, I> +where I: Index +{ + fn visit_indices(&mut self, do_hash: F) -> Result<()> + where F: Fn(&I) -> u64 { + self.visit_indices_update(|i, val| { + let hash = do_hash(i); + if IS_FIRST { + *val = hash; + } else { + *val = merge_hash(*val, hash); + } + }) + } + + fn visit_indices_update(&mut self, update: F) -> Result<()> + where F: Fn(&I, &mut u64) { + for i in self.indices { + let val = &mut self.target[i.to_usize()]; + update(i, val); + } + Ok(()) + } +} + +fn merge_hash(a: u64, b: u64) -> u64 { + a.wrapping_mul(NULL_HASH_VAL) ^ b +} + pub trait AggHash { fn agg_hash(&self) -> u64; } @@ -263,3 +463,142 @@ impl AggHash for ScalarRef<'_> { self.to_string().as_bytes().agg_hash() } } + +#[cfg(test)] +mod tests { + use databend_common_arrow::arrow::bitmap::Bitmap; + + use super::*; + use crate::types::ArgType; + use crate::types::Int32Type; + use crate::types::NullableColumn; + use crate::types::NullableType; + use crate::types::StringType; + use crate::BlockEntry; + use crate::DataBlock; + use crate::FromData; + use crate::Scalar; + use crate::Value; + + fn merge_hash_slice(ls: &[u64]) -> u64 { + ls.iter().cloned().reduce(merge_hash).unwrap() + } + + #[test] + fn test_value_spread() -> Result<()> { + let data = DataBlock::new( + vec![ + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![3, 1, 2, 2, 4, 3, 7, 0, 3])), + ), + BlockEntry::new( + StringType::data_type(), + Value::Scalar(Scalar::String("a".to_string())), + ), + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![3, 1, 3, 2, 2, 3, 4, 3, 3])), + ), + BlockEntry::new( + StringType::data_type(), + Value::Column(StringType::from_data(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", + ])), + ), + ], + 9, + ); + data.check_valid()?; + + { + let mut target = vec![0; data.num_rows()]; + for (i, entry) in data.columns().iter().enumerate() { + let indices = [0, 3, 8]; + group_hash_value_spread(&indices, entry.value.to_owned(), i == 0, &mut target)?; + } + + assert_eq!( + [ + merge_hash_slice(&[ + 3.agg_hash(), + b"a".agg_hash(), + 3.agg_hash(), + b"a".agg_hash(), + ]), + 0, + 0, + merge_hash_slice(&[ + 2.agg_hash(), + b"a".agg_hash(), + 2.agg_hash(), + b"d".agg_hash(), + ]), + 0, + 0, + 0, + 0, + merge_hash_slice(&[ + 3.agg_hash(), + b"a".agg_hash(), + 3.agg_hash(), + b"i".agg_hash(), + ]), + ] + .as_slice(), + &target + ); + } + + { + let c = Int32Type::from_data(vec![3, 1, 2]); + let c = NullableColumn::::new(c, Bitmap::from([true, true, false])); + let nc = NullableType::::upcast_column(c); + + let indices = [0, 1, 2]; + let mut target = vec![0; 3]; + group_hash_value_spread( + &indices, + Value::::Column(nc.clone()), + true, + &mut target, + )?; + + assert_eq!( + [ + merge_hash_slice(&[3.agg_hash()]), + merge_hash_slice(&[1.agg_hash()]), + merge_hash_slice(&[NULL_HASH_VAL]), + ] + .as_slice(), + &target + ); + + let c = Int32Type::from_data(vec![2, 4, 3]); + group_hash_value_spread(&indices, Value::::Column(c), false, &mut target)?; + + assert_eq!( + [ + merge_hash_slice(&[3.agg_hash(), 2.agg_hash()]), + merge_hash_slice(&[1.agg_hash(), 4.agg_hash()]), + merge_hash_slice(&[NULL_HASH_VAL, 3.agg_hash()]), + ] + .as_slice(), + &target + ); + + group_hash_value_spread(&indices, Value::::Column(nc), false, &mut target)?; + + assert_eq!( + [ + merge_hash_slice(&[3.agg_hash(), 2.agg_hash(), 3.agg_hash()]), + merge_hash_slice(&[1.agg_hash(), 4.agg_hash(), 1.agg_hash()]), + merge_hash_slice(&[NULL_HASH_VAL, 3.agg_hash(), NULL_HASH_VAL]), + ] + .as_slice(), + &target + ); + } + Ok(()) + } +} diff --git a/src/query/expression/src/kernels/sort_compare.rs b/src/query/expression/src/kernels/sort_compare.rs index 9f5078c621c39..210054addf61c 100644 --- a/src/query/expression/src/kernels/sort_compare.rs +++ b/src/query/expression/src/kernels/sort_compare.rs @@ -37,6 +37,7 @@ pub struct SortCompare { current_column_index: usize, validity: Option, equality_index: Vec, + force_equality: bool, } macro_rules! do_sorter { @@ -112,12 +113,25 @@ impl SortCompare { current_column_index: 0, validity: None, equality_index, + force_equality: matches!(limit, LimitType::LimitRank(_)), + } + } + + pub fn with_force_equality(ordering_descs: Vec, rows: usize) -> Self { + Self { + rows, + limit: LimitType::None, + permutation: (0..rows as u32).collect(), + ordering_descs, + current_column_index: 0, + validity: None, + equality_index: vec![1; rows as _], + force_equality: true, } } fn need_update_equality_index(&self) -> bool { - self.current_column_index != self.ordering_descs.len() - 1 - || matches!(self.limit, LimitType::LimitRank(_)) + self.force_equality || self.current_column_index != self.ordering_descs.len() - 1 } pub fn increment_column_index(&mut self) { @@ -254,6 +268,11 @@ impl SortCompare { } } } + + pub fn equality_index(&self) -> &[u8] { + debug_assert!(self.force_equality); + &self.equality_index + } } impl ValueVisitor for SortCompare { diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index aa6b837bfa7d1..06837d55f3b46 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -57,6 +57,7 @@ pub use self::empty_map::EmptyMapType; pub use self::generic::GenericType; pub use self::geography::GeographyColumn; pub use self::geography::GeographyType; +pub use self::geometry::GeometryType; pub use self::map::MapType; pub use self::null::NullType; pub use self::nullable::NullableColumn; diff --git a/src/query/expression/src/types/geography.rs b/src/query/expression/src/types/geography.rs index b41aad37d9df9..1a559eab0895b 100644 --- a/src/query/expression/src/types/geography.rs +++ b/src/query/expression/src/types/geography.rs @@ -83,6 +83,12 @@ impl<'a> GeographyRef<'a> { } } +impl<'a> AsRef<[u8]> for GeographyRef<'a> { + fn as_ref(&self) -> &[u8] { + self.0 + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct GeographyType; diff --git a/src/query/expression/src/utils/visitor.rs b/src/query/expression/src/utils/visitor.rs index 535383ce79683..231127ac82549 100755 --- a/src/query/expression/src/utils/visitor.rs +++ b/src/query/expression/src/utils/visitor.rs @@ -36,6 +36,12 @@ pub trait ValueVisitor { self.visit_typed_column::(len) } + fn visit_any_number(&mut self, column: NumberColumn) -> Result<()> { + with_number_type!(|NUM_TYPE| match column { + NumberColumn::NUM_TYPE(b) => self.visit_number(b), + }) + } + fn visit_number( &mut self, column: as ValueType>::Column, @@ -43,6 +49,12 @@ pub trait ValueVisitor { self.visit_typed_column::>(column) } + fn visit_any_decimal(&mut self, column: DecimalColumn) -> Result<()> { + with_decimal_type!(|DECIMAL_TYPE| match column { + DecimalColumn::DECIMAL_TYPE(b, size) => self.visit_decimal(b, size), + }) + } + fn visit_decimal(&mut self, column: Buffer, _size: DecimalSize) -> Result<()> { self.visit_typed_column::>(column) } @@ -113,16 +125,8 @@ pub trait ValueVisitor { Column::Null { len } => self.visit_null(len), Column::EmptyArray { len } => self.visit_empty_array(len), Column::EmptyMap { len } => self.visit_empty_map(len), - Column::Number(column) => { - with_number_type!(|NUM_TYPE| match column { - NumberColumn::NUM_TYPE(b) => self.visit_number(b), - }) - } - Column::Decimal(column) => { - with_decimal_type!(|DECIMAL_TYPE| match column { - DecimalColumn::DECIMAL_TYPE(b, size) => self.visit_decimal(b, size), - }) - } + Column::Number(column) => self.visit_any_number(column), + Column::Decimal(column) => self.visit_any_decimal(column), Column::Boolean(bitmap) => self.visit_boolean(bitmap), Column::Binary(column) => self.visit_binary(column), Column::String(column) => self.visit_string(column), diff --git a/src/query/pipeline/core/src/processors/shuffle_processor.rs b/src/query/pipeline/core/src/processors/shuffle_processor.rs index 6d709988a09f7..724e76bc3839b 100644 --- a/src/query/pipeline/core/src/processors/shuffle_processor.rs +++ b/src/query/pipeline/core/src/processors/shuffle_processor.rs @@ -31,6 +31,7 @@ pub enum MultiwayStrategy { } pub trait Exchange: Send + Sync + 'static { + const NAME: &'static str; const STRATEGY: MultiwayStrategy = MultiwayStrategy::Random; fn partition(&self, data_block: DataBlock, n: usize) -> Result>; @@ -185,7 +186,7 @@ impl PartitionProcessor { impl Processor for PartitionProcessor { fn name(&self) -> String { - String::from("ShufflePartition") + format!("ShufflePartition({})", T::NAME) } fn as_any(&mut self) -> &mut dyn Any { @@ -287,7 +288,7 @@ impl MergePartitionProcessor { impl Processor for MergePartitionProcessor { fn name(&self) -> String { - String::from("ShuffleMergePartition") + format!("ShuffleMergePartition({})", T::NAME) } fn as_any(&mut self) -> &mut dyn Any { diff --git a/src/query/service/src/pipelines/builders/builder_window.rs b/src/query/service/src/pipelines/builders/builder_window.rs index 252bd50b592d0..4205eb35cc564 100644 --- a/src/query/service/src/pipelines/builders/builder_window.rs +++ b/src/query/service/src/pipelines/builders/builder_window.rs @@ -34,6 +34,7 @@ use crate::pipelines::processors::transforms::TransformWindow; use crate::pipelines::processors::transforms::TransformWindowPartitionCollect; use crate::pipelines::processors::transforms::WindowFunctionInfo; use crate::pipelines::processors::transforms::WindowPartitionExchange; +use crate::pipelines::processors::transforms::WindowPartitionTopNExchange; use crate::pipelines::processors::transforms::WindowSortDesc; use crate::pipelines::processors::transforms::WindowSpillSettings; use crate::pipelines::PipelineBuilder; @@ -169,10 +170,23 @@ impl PipelineBuilder { }) .collect::>>()?; - self.main_pipeline.exchange( - num_processors, - WindowPartitionExchange::create(partition_by.clone(), num_partitions), - ); + if let Some(top_n) = &window_partition.top_n { + self.main_pipeline.exchange( + num_processors, + WindowPartitionTopNExchange::create( + partition_by.clone(), + sort_desc.clone(), + top_n.top, + top_n.func, + num_partitions as u64, + ), + ) + } else { + self.main_pipeline.exchange( + num_processors, + WindowPartitionExchange::create(partition_by.clone(), num_partitions), + ); + } let disk_bytes_limit = settings.get_window_partition_spilling_to_disk_bytes_limit()?; let temp_dir_manager = TempDirManager::instance(); diff --git a/src/query/service/src/pipelines/processors/transforms/window/partition/mod.rs b/src/query/service/src/pipelines/processors/transforms/window/partition/mod.rs index b75ff9d671cf6..3eebde7955a8b 100644 --- a/src/query/service/src/pipelines/processors/transforms/window/partition/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/window/partition/mod.rs @@ -16,8 +16,10 @@ mod transform_window_partition_collect; mod window_partition_buffer; mod window_partition_exchange; mod window_partition_meta; +mod window_partition_partial_top_n_exchange; pub use transform_window_partition_collect::*; pub use window_partition_buffer::*; pub use window_partition_exchange::*; pub use window_partition_meta::*; +pub use window_partition_partial_top_n_exchange::*; diff --git a/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_exchange.rs b/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_exchange.rs index a23e6b030c61c..bf6ea988acf65 100644 --- a/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_exchange.rs +++ b/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_exchange.rs @@ -15,10 +15,9 @@ use std::sync::Arc; use databend_common_exception::Result; -use databend_common_expression::group_hash_columns_slice; -use databend_common_expression::ColumnBuilder; +use databend_common_expression::group_hash_columns; use databend_common_expression::DataBlock; -use databend_common_expression::Value; +use databend_common_expression::InputColumns; use databend_common_pipeline_core::processors::Exchange; use super::WindowPartitionMeta; @@ -38,27 +37,17 @@ impl WindowPartitionExchange { } impl Exchange for WindowPartitionExchange { + const NAME: &'static str = "Window"; fn partition(&self, data_block: DataBlock, n: usize) -> Result> { let num_rows = data_block.num_rows(); // Extract the columns used for hash computation. - let hash_cols = self - .hash_keys - .iter() - .map(|&offset| { - let entry = data_block.get_by_offset(offset); - match &entry.value { - Value::Scalar(s) => { - ColumnBuilder::repeat(&s.as_ref(), num_rows, &entry.data_type).build() - } - Value::Column(c) => c.clone(), - } - }) - .collect::>(); + let data_block = data_block.consume_convert_to_full(); + let hash_cols = InputColumns::new_block_proxy(&self.hash_keys, &data_block); // Compute the hash value for each row. let mut hashes = vec![0u64; num_rows]; - group_hash_columns_slice(&hash_cols, &mut hashes); + group_hash_columns(hash_cols, &mut hashes); // Scatter the data block to different partitions. let indices = hashes diff --git a/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_partial_top_n_exchange.rs b/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_partial_top_n_exchange.rs new file mode 100644 index 0000000000000..8601895dcffd3 --- /dev/null +++ b/src/query/service/src/pipelines/processors/transforms/window/partition/window_partition_partial_top_n_exchange.rs @@ -0,0 +1,336 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use databend_common_exception::Result; +use databend_common_expression::group_hash_value_spread; +use databend_common_expression::visitor::ValueVisitor; +use databend_common_expression::DataBlock; +use databend_common_expression::SortColumnDescription; +use databend_common_expression::SortCompare; +use databend_common_pipeline_core::processors::Exchange; + +use super::WindowPartitionMeta; +use crate::sql::executor::physical_plans::WindowPartitionTopNFunc; + +pub struct WindowPartitionTopNExchange { + partition_indices: Box<[usize]>, + top: usize, + func: WindowPartitionTopNFunc, + + sort_desc: Box<[SortColumnDescription]>, + num_partitions: u64, +} + +impl WindowPartitionTopNExchange { + pub fn create( + partition_indices: Vec, + order_by: Vec, + top: usize, + func: WindowPartitionTopNFunc, + num_partitions: u64, + ) -> Arc { + assert!(top > 0); + let partition_indices = partition_indices.into_boxed_slice(); + let sort_desc = partition_indices + .iter() + .map(|&offset| SortColumnDescription { + offset, + asc: true, + nulls_first: false, + }) + .chain(order_by) + .collect::>() + .into(); + + Arc::new(WindowPartitionTopNExchange { + num_partitions, + partition_indices, + top, + func, + sort_desc, + }) + } +} + +impl Exchange for WindowPartitionTopNExchange { + const NAME: &'static str = "WindowTopN"; + fn partition(&self, block: DataBlock, n: usize) -> Result> { + let partition_permutation = self.partition_permutation(&block); + + // Partition the data blocks to different processors. + let mut output_data_blocks = vec![vec![]; n]; + for (partition_id, indices) in partition_permutation.into_iter().enumerate() { + output_data_blocks[partition_id % n].push((partition_id, block.take(&indices)?)); + } + + // Union data blocks for each processor. + Ok(output_data_blocks + .into_iter() + .map(WindowPartitionMeta::create) + .map(DataBlock::empty_with_meta) + .collect()) + } +} + +impl WindowPartitionTopNExchange { + fn partition_permutation(&self, block: &DataBlock) -> Vec> { + let rows = block.num_rows(); + let mut sort_compare = SortCompare::with_force_equality(self.sort_desc.to_vec(), rows); + + for &offset in &self.partition_indices { + let array = block.get_by_offset(offset).value.clone(); + sort_compare.visit_value(array).unwrap(); + sort_compare.increment_column_index(); + } + + let partition_equality = sort_compare.equality_index().to_vec(); + + for desc in self.sort_desc.iter().skip(self.partition_indices.len()) { + let array = block.get_by_offset(desc.offset).value.clone(); + sort_compare.visit_value(array).unwrap(); + sort_compare.increment_column_index(); + } + + let full_equality = sort_compare.equality_index().to_vec(); + let permutation = sort_compare.take_permutation(); + + let hash_indices = std::iter::once(permutation[0]) + .chain( + partition_equality + .iter() + .enumerate() + .filter_map(|(i, &eq)| if eq == 0 { Some(permutation[i]) } else { None }), + ) + .collect::>(); + + let mut hashes = vec![0u64; rows]; + for (i, &offset) in self.partition_indices.iter().enumerate() { + let entry = block.get_by_offset(offset); + group_hash_value_spread(&hash_indices, entry.value.to_owned(), i == 0, &mut hashes) + .unwrap(); + } + + let mut partition_permutation = vec![Vec::new(); self.num_partitions as usize]; + + let mut start = 0; + let mut cur = 0; + while cur < rows { + let partition = &mut partition_permutation + [(hashes[permutation[start] as usize] % self.num_partitions) as usize]; + partition.push(permutation[start]); + + let mut rank = 0; // this first value is rank 0 + cur = start + 1; + while cur < rows { + if partition_equality[cur] == 0 { + start = cur; + break; + } + + match self.func { + WindowPartitionTopNFunc::RowNumber => { + if cur - start < self.top { + partition.push(permutation[cur]); + } + } + WindowPartitionTopNFunc::Rank | WindowPartitionTopNFunc::DenseRank => { + if full_equality[cur] == 0 { + if matches!(self.func, WindowPartitionTopNFunc::Rank) { + rank = cur - start + } else { + rank += 1 + } + } + + if rank < self.top { + partition.push(permutation[cur]); + } + } + } + cur += 1; + } + } + partition_permutation + } +} + +#[cfg(test)] +mod tests { + use databend_common_expression::types::ArgType; + use databend_common_expression::types::Int32Type; + use databend_common_expression::types::StringType; + use databend_common_expression::BlockEntry; + use databend_common_expression::FromData; + use databend_common_expression::Scalar; + use databend_common_expression::Value; + + use super::*; + + #[test] + fn test_row_number() -> Result<()> { + let p = WindowPartitionTopNExchange::create( + vec![1, 2], + vec![SortColumnDescription { + offset: 0, + asc: true, + nulls_first: false, + }], + 3, + WindowPartitionTopNFunc::RowNumber, + 8, + ); + + let data = DataBlock::new( + vec![ + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![3, 1, 2, 2, 4, 3, 7, 0, 3])), + ), + BlockEntry::new( + StringType::data_type(), + Value::Scalar(Scalar::String("a".to_string())), + ), + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![3, 1, 3, 2, 2, 3, 4, 3, 3])), + ), + BlockEntry::new( + StringType::data_type(), + Value::Column(StringType::from_data(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", + ])), + ), + ], + 9, + ); + data.check_valid()?; + + let got = p.partition_permutation(&data); + + let want = vec![ + vec![], + vec![1], + vec![], + vec![3, 4], + vec![], + vec![6], + vec![], + vec![7, 2, 0], + ]; + // if got != want { + // let got = got + // .iter() + // .map(|indices| data.take(indices, &mut None).unwrap()) + // .collect::>(); + // for x in got { + // println!("{}", x) + // } + // } + assert_eq!(&want, &got); + + Ok(()) + } + + #[test] + fn test_rank() -> Result<()> { + let p = WindowPartitionTopNExchange::create( + vec![1], + vec![SortColumnDescription { + offset: 0, + asc: true, + nulls_first: false, + }], + 3, + WindowPartitionTopNFunc::Rank, + 8, + ); + + let data = DataBlock::new( + vec![ + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![7, 7, 7, 6, 5, 5, 4, 1, 3, 1, 1])), + ), + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![7, 6, 5, 5, 5, 4, 3, 3, 2, 3, 3])), + ), + ], + 11, + ); + data.check_valid()?; + + let got = p.partition_permutation(&data); + + let want = vec![ + vec![], + vec![1], + vec![8, 0], + vec![], + vec![5, 4, 3, 2], + vec![], + vec![7, 9, 10], + vec![], + ]; + assert_eq!(&want, &got); + Ok(()) + } + + #[test] + fn test_dense_rank() -> Result<()> { + let p = WindowPartitionTopNExchange::create( + vec![1], + vec![SortColumnDescription { + offset: 0, + asc: true, + nulls_first: false, + }], + 3, + WindowPartitionTopNFunc::DenseRank, + 8, + ); + + let data = DataBlock::new( + vec![ + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![5, 2, 3, 3, 2, 2, 1, 1, 1, 1, 1])), + ), + BlockEntry::new( + Int32Type::data_type(), + Value::Column(Int32Type::from_data(vec![2, 2, 4, 3, 2, 2, 5, 4, 3, 3, 3])), + ), + ], + 11, + ); + data.check_valid()?; + + let got = p.partition_permutation(&data); + + let want = vec![ + vec![], + vec![], + vec![1, 4, 5, 0], + vec![], + vec![7, 2, 6], + vec![], + vec![8, 9, 10, 3], + vec![], + ]; + assert_eq!(&want, &got); + Ok(()) + } +} diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index 6dcab582c7f34..dd86b4cda858a 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -242,6 +242,7 @@ pub trait PhysicalPlanReplacer { partition_by: plan.partition_by.clone(), order_by: plan.order_by.clone(), after_exchange: plan.after_exchange, + top_n: plan.top_n.clone(), stat_info: plan.stat_info.clone(), })) } diff --git a/src/query/sql/src/executor/physical_plans/mod.rs b/src/query/sql/src/executor/physical_plans/mod.rs index f8b1c86702d72..957443396364a 100644 --- a/src/query/sql/src/executor/physical_plans/mod.rs +++ b/src/query/sql/src/executor/physical_plans/mod.rs @@ -57,6 +57,7 @@ mod physical_table_scan; mod physical_udf; mod physical_union_all; mod physical_window; +mod physical_window_partition; pub use common::*; pub use physical_add_stream_column::AddStreamColumn; @@ -101,10 +102,9 @@ pub use physical_replace_deduplicate::*; pub use physical_replace_into::ReplaceInto; pub use physical_row_fetch::RowFetch; pub use physical_sort::Sort; -mod physical_window_partition; pub use physical_table_scan::TableScan; pub use physical_udf::Udf; pub use physical_udf::UdfFunctionDesc; pub use physical_union_all::UnionAll; pub use physical_window::*; -pub use physical_window_partition::WindowPartition; +pub use physical_window_partition::*; diff --git a/src/query/sql/src/executor/physical_plans/physical_sort.rs b/src/query/sql/src/executor/physical_plans/physical_sort.rs index c535a33a2fccd..8bfc63719f2dc 100644 --- a/src/query/sql/src/executor/physical_plans/physical_sort.rs +++ b/src/query/sql/src/executor/physical_plans/physical_sort.rs @@ -24,9 +24,12 @@ use itertools::Itertools; use crate::executor::explain::PlanStatsInfo; use crate::executor::physical_plans::common::SortDesc; use crate::executor::physical_plans::WindowPartition; +use crate::executor::physical_plans::WindowPartitionTopN; +use crate::executor::physical_plans::WindowPartitionTopNFunc; use crate::executor::PhysicalPlan; use crate::executor::PhysicalPlanBuilder; use crate::optimizer::SExpr; +use crate::plans::WindowFuncType; use crate::ColumnSet; use crate::IndexType; @@ -122,12 +125,6 @@ impl PhysicalPlanBuilder { let input_plan = self.build(s_expr.child(0)?, required).await?; - let window_partition = sort - .window_partition - .iter() - .map(|v| v.index) - .collect::>(); - let order_by = sort .items .iter() @@ -140,16 +137,31 @@ impl PhysicalPlanBuilder { .collect::>(); // Add WindowPartition for parallel sort in window. - if !window_partition.is_empty() { + if let Some(window) = &sort.window_partition { + let window_partition = window + .partition_by + .iter() + .map(|v| v.index) + .collect::>(); + return Ok(PhysicalPlan::WindowPartition(WindowPartition { plan_id: 0, input: Box::new(input_plan.clone()), partition_by: window_partition.clone(), order_by: order_by.clone(), after_exchange: sort.after_exchange, + top_n: window.top.map(|top| WindowPartitionTopN { + func: match window.func { + WindowFuncType::RowNumber => WindowPartitionTopNFunc::RowNumber, + WindowFuncType::Rank => WindowPartitionTopNFunc::Rank, + WindowFuncType::DenseRank => WindowPartitionTopNFunc::DenseRank, + _ => unreachable!(), + }, + top, + }), stat_info: Some(stat_info.clone()), })); - } + }; // 2. Build physical plan. Ok(PhysicalPlan::Sort(Sort { diff --git a/src/query/sql/src/executor/physical_plans/physical_window_partition.rs b/src/query/sql/src/executor/physical_plans/physical_window_partition.rs index 002daa1955f00..b0ff12d3f8685 100644 --- a/src/query/sql/src/executor/physical_plans/physical_window_partition.rs +++ b/src/query/sql/src/executor/physical_plans/physical_window_partition.rs @@ -27,10 +27,24 @@ pub struct WindowPartition { pub partition_by: Vec, pub order_by: Vec, pub after_exchange: Option, + pub top_n: Option, pub stat_info: Option, } +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct WindowPartitionTopN { + pub func: WindowPartitionTopNFunc, + pub top: usize, +} + +#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] +pub enum WindowPartitionTopNFunc { + RowNumber, + Rank, + DenseRank, +} + impl WindowPartition { pub fn output_schema(&self) -> Result { self.input.output_schema() diff --git a/src/query/sql/src/planner/binder/bind_query/bind.rs b/src/query/sql/src/planner/binder/bind_query/bind.rs index e2c0b565f3774..f3556bc021707 100644 --- a/src/query/sql/src/planner/binder/bind_query/bind.rs +++ b/src/query/sql/src/planner/binder/bind_query/bind.rs @@ -159,7 +159,7 @@ impl Binder { limit: None, after_exchange: None, pre_projection: None, - window_partition: vec![], + window_partition: None, }; Ok(SExpr::create_unary( Arc::new(sort_plan.into()), diff --git a/src/query/sql/src/planner/binder/sort.rs b/src/query/sql/src/planner/binder/sort.rs index 5deb9ed8c1a67..c3d630fce11b4 100644 --- a/src/query/sql/src/planner/binder/sort.rs +++ b/src/query/sql/src/planner/binder/sort.rs @@ -217,7 +217,7 @@ impl Binder { limit: None, after_exchange: None, pre_projection: None, - window_partition: vec![], + window_partition: None, }; let new_expr = SExpr::create_unary(Arc::new(sort_plan.into()), Arc::new(child)); Ok(new_expr) diff --git a/src/query/sql/src/planner/binder/window.rs b/src/query/sql/src/planner/binder/window.rs index f10a088859d0f..e43511df31b40 100644 --- a/src/query/sql/src/planner/binder/window.rs +++ b/src/query/sql/src/planner/binder/window.rs @@ -42,6 +42,7 @@ use crate::plans::WindowFunc; use crate::plans::WindowFuncFrame; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; +use crate::plans::WindowPartition; use crate::BindContext; use crate::Binder; use crate::ColumnEntry; @@ -116,10 +117,18 @@ impl Binder { let child = if !sort_items.is_empty() { let sort_plan = Sort { items: sort_items, - limit: window_plan.limit, + limit: None, after_exchange: None, pre_projection: None, - window_partition: window_plan.partition_by.clone(), + window_partition: if window_plan.partition_by.is_empty() { + None + } else { + Some(WindowPartition { + partition_by: window_plan.partition_by.clone(), + top: None, + func: window_plan.function.clone(), + }) + }, }; SExpr::create_unary(Arc::new(sort_plan.into()), Arc::new(child)) } else { diff --git a/src/query/sql/src/planner/format/display_rel_operator.rs b/src/query/sql/src/planner/format/display_rel_operator.rs index f4d236f9c249b..23e7c43291323 100644 --- a/src/query/sql/src/planner/format/display_rel_operator.rs +++ b/src/query/sql/src/planner/format/display_rel_operator.rs @@ -385,10 +385,23 @@ fn sort_to_format_tree .join(", "); let limit = op.limit.map_or("NONE".to_string(), |l| l.to_string()); - FormatTreeNode::with_children("Sort".to_string(), vec![ - FormatTreeNode::new(format!("sort keys: [{}]", scalars)), - FormatTreeNode::new(format!("limit: [{}]", limit)), - ]) + let children = match &op.window_partition { + Some(window) => vec![ + FormatTreeNode::new(format!("sort keys: [{}]", scalars)), + FormatTreeNode::new(format!("limit: [{}]", limit)), + FormatTreeNode::new(format!( + "window top: {}", + window.top.map_or("NONE".to_string(), |n| n.to_string()) + )), + FormatTreeNode::new(format!("window function: {:?}", window.func)), + ], + None => vec![ + FormatTreeNode::new(format!("sort keys: [{}]", scalars)), + FormatTreeNode::new(format!("limit: [{}]", limit)), + ], + }; + + FormatTreeNode::with_children("Sort".to_string(), children) } fn constant_scan_to_format_tree>( diff --git a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs index d973b91a69447..5a2b34ceec354 100644 --- a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs @@ -162,10 +162,13 @@ impl SubqueryRewriter { RelOperator::Sort(mut sort) => { let mut input = self.rewrite(s_expr.child(0)?)?; - for item in sort.window_partition.iter_mut() { - let res = self.try_rewrite_subquery(&item.scalar, &input, false)?; - input = res.1; - item.scalar = res.0; + + if let Some(window) = &mut sort.window_partition { + for item in window.partition_by.iter_mut() { + let res = self.try_rewrite_subquery(&item.scalar, &input, false)?; + input = res.1; + item.scalar = res.0; + } } Ok(SExpr::create_unary(Arc::new(sort.into()), Arc::new(input))) diff --git a/src/query/sql/src/planner/optimizer/rule/factory.rs b/src/query/sql/src/planner/optimizer/rule/factory.rs index 98ece251cf6b2..50ab1ed44ef99 100644 --- a/src/query/sql/src/planner/optimizer/rule/factory.rs +++ b/src/query/sql/src/planner/optimizer/rule/factory.rs @@ -16,41 +16,42 @@ use databend_common_exception::Result; use super::rewrite::RuleCommuteJoin; use super::rewrite::RuleEliminateEvalScalar; +use super::rewrite::RuleEliminateFilter; +use super::rewrite::RuleEliminateSort; use super::rewrite::RuleEliminateUnion; +use super::rewrite::RuleFilterNulls; use super::rewrite::RuleFoldCountAggregate; +use super::rewrite::RuleMergeEvalScalar; +use super::rewrite::RuleMergeFilter; use super::rewrite::RuleNormalizeScalarFilter; use super::rewrite::RulePushDownFilterAggregate; use super::rewrite::RulePushDownFilterEvalScalar; use super::rewrite::RulePushDownFilterJoin; +use super::rewrite::RulePushDownFilterProjectSet; +use super::rewrite::RulePushDownFilterScan; +use super::rewrite::RulePushDownFilterSort; +use super::rewrite::RulePushDownFilterUnion; use super::rewrite::RulePushDownFilterWindow; +use super::rewrite::RulePushDownFilterWindowTopN; +use super::rewrite::RulePushDownLimit; use super::rewrite::RulePushDownLimitEvalScalar; +use super::rewrite::RulePushDownLimitOuterJoin; +use super::rewrite::RulePushDownLimitScan; +use super::rewrite::RulePushDownLimitSort; +use super::rewrite::RulePushDownLimitUnion; +use super::rewrite::RulePushDownLimitWindow; use super::rewrite::RulePushDownPrewhere; use super::rewrite::RulePushDownRankLimitAggregate; use super::rewrite::RulePushDownSortEvalScalar; +use super::rewrite::RulePushDownSortScan; +use super::rewrite::RuleSemiToInnerJoin; +use super::rewrite::RuleSplitAggregate; use super::rewrite::RuleTryApplyAggIndex; -use crate::optimizer::rule::rewrite::RuleEliminateFilter; -use crate::optimizer::rule::rewrite::RuleEliminateSort; -use crate::optimizer::rule::rewrite::RuleFilterNulls; -use crate::optimizer::rule::rewrite::RuleMergeEvalScalar; -use crate::optimizer::rule::rewrite::RuleMergeFilter; -use crate::optimizer::rule::rewrite::RulePushDownFilterProjectSet; -use crate::optimizer::rule::rewrite::RulePushDownFilterScan; -use crate::optimizer::rule::rewrite::RulePushDownFilterSort; -use crate::optimizer::rule::rewrite::RulePushDownFilterUnion; -use crate::optimizer::rule::rewrite::RulePushDownLimit; -use crate::optimizer::rule::rewrite::RulePushDownLimitOuterJoin; -use crate::optimizer::rule::rewrite::RulePushDownLimitScan; -use crate::optimizer::rule::rewrite::RulePushDownLimitSort; -use crate::optimizer::rule::rewrite::RulePushDownLimitUnion; -use crate::optimizer::rule::rewrite::RulePushDownLimitWindow; -use crate::optimizer::rule::rewrite::RulePushDownSortScan; -use crate::optimizer::rule::rewrite::RuleSemiToInnerJoin; -use crate::optimizer::rule::rewrite::RuleSplitAggregate; -use crate::optimizer::rule::transform::RuleCommuteJoinBaseTable; -use crate::optimizer::rule::transform::RuleEagerAggregation; -use crate::optimizer::rule::transform::RuleLeftExchangeJoin; -use crate::optimizer::rule::RuleID; -use crate::optimizer::rule::RulePtr; +use super::transform::RuleCommuteJoinBaseTable; +use super::transform::RuleEagerAggregation; +use super::transform::RuleLeftExchangeJoin; +use super::RuleID; +use super::RulePtr; use crate::optimizer::OptimizerContext; pub struct RuleFactory; @@ -91,6 +92,7 @@ impl RuleFactory { } RuleID::PushDownFilterAggregate => Ok(Box::new(RulePushDownFilterAggregate::new())), RuleID::PushDownFilterWindow => Ok(Box::new(RulePushDownFilterWindow::new())), + RuleID::PushDownFilterWindowRank => Ok(Box::new(RulePushDownFilterWindowTopN::new())), RuleID::EliminateFilter => Ok(Box::new(RuleEliminateFilter::new(ctx.metadata))), RuleID::MergeEvalScalar => Ok(Box::new(RuleMergeEvalScalar::new())), RuleID::MergeFilter => Ok(Box::new(RuleMergeFilter::new())), diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs index d8c306a88bfab..c564477e2f32f 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs @@ -32,6 +32,7 @@ mod rule_push_down_filter_scan; mod rule_push_down_filter_sort; mod rule_push_down_filter_union; mod rule_push_down_filter_window; +mod rule_push_down_filter_window_top_n; mod rule_push_down_limit; mod rule_push_down_limit_aggregate; mod rule_push_down_limit_expression; @@ -66,6 +67,7 @@ pub use rule_push_down_filter_scan::RulePushDownFilterScan; pub use rule_push_down_filter_sort::RulePushDownFilterSort; pub use rule_push_down_filter_union::RulePushDownFilterUnion; pub use rule_push_down_filter_window::RulePushDownFilterWindow; +pub use rule_push_down_filter_window_top_n::RulePushDownFilterWindowTopN; pub use rule_push_down_limit::RulePushDownLimit; pub use rule_push_down_limit_aggregate::RulePushDownRankLimitAggregate; pub use rule_push_down_limit_expression::RulePushDownLimitEvalScalar; diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_eliminate_sort.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_eliminate_sort.rs index 6201891de8f5b..4aacb621359d8 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_eliminate_sort.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_eliminate_sort.rs @@ -55,14 +55,14 @@ impl Rule for RuleEliminateSort { let rel_expr = RelExpr::with_s_expr(input); let prop = rel_expr.derive_relational_prop()?; - if !sort.window_partition.is_empty() { + if let Some(window) = &sort.window_partition { if let Some((partition, ordering)) = &prop.partition_orderings { // must has same partition // if the ordering of the current node is empty, we can eliminate the sort // eg: explain select number, sum(number - 1) over (partition by number % 3 order by number + 1), // avg(number) over (partition by number % 3 order by number + 1) // from numbers(50); - if partition == &sort.window_partition + if partition == &window.partition_by && (ordering == &sort.items || sort.sort_items_exclude_partition().is_empty()) { state.add_result(input.clone()); diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_window_top_n.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_window_top_n.rs new file mode 100644 index 0000000000000..136aa8622076a --- /dev/null +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_window_top_n.rs @@ -0,0 +1,185 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use databend_common_exception::Result; +use databend_common_expression::type_check::check_number; +use databend_common_expression::FunctionContext; +use databend_common_functions::BUILTIN_FUNCTIONS; + +use crate::optimizer::extract::Matcher; +use crate::optimizer::rule::Rule; +use crate::optimizer::rule::TransformResult; +use crate::optimizer::RuleID; +use crate::optimizer::SExpr; +use crate::plans::ComparisonOp; +use crate::plans::Filter; +use crate::plans::RelOp; +use crate::plans::ScalarExpr; +use crate::plans::Sort; +use crate::plans::Window; +use crate::plans::WindowFuncType; + +/// Input: Filter +/// \ +/// Window +/// \ +/// Sort +/// +/// Output: Filter +/// \ +/// Window +/// \ +/// Sort(top n) +pub struct RulePushDownFilterWindowTopN { + id: RuleID, + matchers: Vec, +} + +impl RulePushDownFilterWindowTopN { + pub fn new() -> Self { + Self { + id: RuleID::PushDownFilterWindowRank, + matchers: vec![Matcher::MatchOp { + op_type: RelOp::Filter, + children: vec![Matcher::MatchOp { + op_type: RelOp::Window, + children: vec![Matcher::MatchOp { + op_type: RelOp::Sort, + children: vec![Matcher::Leaf], + }], + }], + }], + } + } +} + +impl Rule for RulePushDownFilterWindowTopN { + fn id(&self) -> RuleID { + self.id + } + + fn apply(&self, s_expr: &SExpr, state: &mut TransformResult) -> Result<()> { + let filter: Filter = s_expr.plan().clone().try_into()?; + let window_expr = s_expr.child(0)?; + let window: Window = window_expr.plan().clone().try_into()?; + let sort_expr = window_expr.child(0)?; + let mut sort: Sort = sort_expr.plan().clone().try_into()?; + + if !is_ranking_function(&window.function) || sort.window_partition.is_none() { + return Ok(()); + } + + let predicates = filter + .predicates + .into_iter() + .filter_map(|predicate| extract_top_n(window.index, predicate)) + .collect::>(); + + let Some(top_n) = predicates.into_iter().min() else { + return Ok(()); + }; + + if top_n == 0 { + // TODO + return Ok(()); + } + + sort.window_partition.as_mut().unwrap().top = Some(top_n); + + let mut result = SExpr::create_unary( + s_expr.plan.clone(), + SExpr::create_unary( + window_expr.plan.clone(), + sort_expr.replace_plan(Arc::new(sort.into())).into(), + ) + .into(), + ); + result.set_applied_rule(&self.id); + + state.add_result(result); + + Ok(()) + } + + fn matchers(&self) -> &[Matcher] { + &self.matchers + } +} + +fn extract_top_n(column: usize, predicate: ScalarExpr) -> Option { + let ScalarExpr::FunctionCall(call) = predicate else { + return None; + }; + + let func_name = &call.func_name; + if func_name == ComparisonOp::Equal.to_func_name() { + return match (&call.arguments[0], &call.arguments[1]) { + (ScalarExpr::BoundColumnRef(col), number) + | (number, ScalarExpr::BoundColumnRef(col)) + if col.column.index == column => + { + extract_i32(number).map(|n| n.max(0) as usize) + } + _ => None, + }; + } + + let (left, right) = match ( + func_name == ComparisonOp::LTE.to_func_name() + || func_name == ComparisonOp::LT.to_func_name(), + func_name == ComparisonOp::GTE.to_func_name() + || func_name == ComparisonOp::GT.to_func_name(), + ) { + (true, _) => (&call.arguments[0], &call.arguments[1]), + (_, true) => (&call.arguments[1], &call.arguments[0]), + _ => return None, + }; + + let ScalarExpr::BoundColumnRef(col) = left else { + return None; + }; + if col.column.index != column { + return None; + } + + let eq = func_name == ComparisonOp::GTE.to_func_name() + || func_name == ComparisonOp::LTE.to_func_name(); + + extract_i32(right).map(|n| { + if eq { + n.max(0) as usize + } else { + n.max(1) as usize - 1 + } + }) +} + +fn extract_i32(expr: &ScalarExpr) -> Option { + check_number( + None, + &FunctionContext::default(), + &expr.as_expr().ok()?, + &BUILTIN_FUNCTIONS, + ) + .ok() +} + +fn is_ranking_function(func: &WindowFuncType) -> bool { + matches!( + func, + WindowFuncType::RowNumber | WindowFuncType::Rank | WindowFuncType::DenseRank + ) +} diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs index ad1cae65cf992..aaf241cd5e63f 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs @@ -105,7 +105,7 @@ impl RulePushDownRankLimitAggregate { limit: Some(count), after_exchange: None, pre_projection: None, - window_partition: vec![], + window_partition: None, }; let agg = SExpr::create_unary( diff --git a/src/query/sql/src/planner/optimizer/rule/rule.rs b/src/query/sql/src/planner/optimizer/rule/rule.rs index b765aba22c820..f13a67a6b5a54 100644 --- a/src/query/sql/src/planner/optimizer/rule/rule.rs +++ b/src/query/sql/src/planner/optimizer/rule/rule.rs @@ -37,6 +37,7 @@ pub static DEFAULT_REWRITE_RULES: LazyLock> = LazyLock::new(|| { RuleID::PushDownFilterUnion, RuleID::PushDownFilterAggregate, RuleID::PushDownFilterWindow, + RuleID::PushDownFilterWindowRank, RuleID::PushDownFilterSort, RuleID::PushDownFilterEvalScalar, RuleID::PushDownFilterJoin, @@ -90,6 +91,7 @@ pub enum RuleID { PushDownFilterSort, PushDownFilterProjectSet, PushDownFilterWindow, + PushDownFilterWindowRank, PushDownLimit, PushDownLimitUnion, PushDownLimitOuterJoin, @@ -140,6 +142,7 @@ impl Display for RuleID { RuleID::PushDownSortEvalScalar => write!(f, "PushDownSortEvalScalar"), RuleID::PushDownLimitWindow => write!(f, "PushDownLimitWindow"), RuleID::PushDownFilterWindow => write!(f, "PushDownFilterWindow"), + RuleID::PushDownFilterWindowRank => write!(f, "PushDownFilterWindowRank"), RuleID::EliminateEvalScalar => write!(f, "EliminateEvalScalar"), RuleID::EliminateFilter => write!(f, "EliminateFilter"), RuleID::EliminateSort => write!(f, "EliminateSort"), diff --git a/src/query/sql/src/planner/plans/sort.rs b/src/query/sql/src/planner/plans/sort.rs index cc31a7ed5fbef..1f7a1536f2878 100644 --- a/src/query/sql/src/planner/plans/sort.rs +++ b/src/query/sql/src/planner/plans/sort.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use databend_common_catalog::table_context::TableContext; use databend_common_exception::Result; +use super::WindowPartition; use crate::optimizer::Distribution; use crate::optimizer::PhysicalProperty; use crate::optimizer::RelExpr; @@ -25,7 +26,6 @@ use crate::optimizer::RequiredProperty; use crate::optimizer::StatInfo; use crate::plans::Operator; use crate::plans::RelOp; -use crate::plans::ScalarItem; use crate::ColumnSet; use crate::IndexType; @@ -43,7 +43,7 @@ pub struct Sort { pub pre_projection: Option>, /// If sort is for window clause, we need the input to exchange by partitions - pub window_partition: Vec, + pub window_partition: Option, } impl Sort { @@ -54,11 +54,12 @@ impl Sort { pub fn sort_items_exclude_partition(&self) -> Vec { self.items .iter() - .filter(|item| { - !self - .window_partition + .filter(|item| match &self.window_partition { + Some(window) => !window + .partition_by .iter() - .any(|partition| partition.index == item.index) + .any(|partition| partition.index == item.index), + None => true, }) .cloned() .collect() @@ -79,14 +80,15 @@ impl Operator for Sort { fn derive_physical_prop(&self, rel_expr: &RelExpr) -> Result { let input_physical_prop = rel_expr.derive_physical_prop_child(0)?; - if input_physical_prop.distribution == Distribution::Serial - || self.window_partition.is_empty() - { + if input_physical_prop.distribution == Distribution::Serial { return Ok(input_physical_prop); } + let Some(window) = &self.window_partition else { + return Ok(input_physical_prop); + }; - let partition_by = self - .window_partition + let partition_by = window + .partition_by .iter() .map(|s| s.scalar.clone()) .collect(); @@ -105,9 +107,9 @@ impl Operator for Sort { let mut required = required.clone(); required.distribution = Distribution::Serial; - if self.window_partition.is_empty() { + let Some(window) = &self.window_partition else { return Ok(required); - } + }; let child_physical_prop = rel_expr.derive_physical_prop_child(0)?; // Can't merge to shuffle @@ -115,8 +117,8 @@ impl Operator for Sort { return Ok(required); } - let partition_by = self - .window_partition + let partition_by = window + .partition_by .iter() .map(|s| s.scalar.clone()) .collect(); @@ -134,9 +136,9 @@ impl Operator for Sort { let mut required = required.clone(); required.distribution = Distribution::Serial; - if self.window_partition.is_empty() { + let Some(window) = &self.window_partition else { return Ok(vec![vec![required]]); - } + }; // Can't merge to shuffle let child_physical_prop = rel_expr.derive_physical_prop_child(0)?; @@ -144,8 +146,8 @@ impl Operator for Sort { return Ok(vec![vec![required]]); } - let partition_by = self - .window_partition + let partition_by = window + .partition_by .iter() .map(|s| s.scalar.clone()) .collect(); @@ -163,13 +165,13 @@ impl Operator for Sort { // Derive orderings let orderings = self.items.clone(); - let (orderings, partition_orderings) = if !self.window_partition.is_empty() { - ( + + let (orderings, partition_orderings) = match &self.window_partition { + Some(window) => ( vec![], - Some((self.window_partition.clone(), orderings.clone())), - ) - } else { - (self.items.clone(), None) + Some((window.partition_by.clone(), orderings.clone())), + ), + None => (self.items.clone(), None), }; Ok(Arc::new(RelationalProperty { diff --git a/src/query/sql/src/planner/plans/window.rs b/src/query/sql/src/planner/plans/window.rs index da9746d377a8f..b3fc44536ad72 100644 --- a/src/query/sql/src/planner/plans/window.rs +++ b/src/query/sql/src/planner/plans/window.rs @@ -283,3 +283,10 @@ impl WindowFuncType { } } } + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct WindowPartition { + pub partition_by: Vec, + pub top: Option, + pub func: WindowFuncType, +} diff --git a/tests/sqllogictests/suites/mode/standalone/explain/window.test b/tests/sqllogictests/suites/mode/standalone/explain/window.test index 971b3dd1dccc7..37044436493ed 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain/window.test +++ b/tests/sqllogictests/suites/mode/standalone/explain/window.test @@ -57,8 +57,8 @@ CompoundBlockOperator(Project) × 1 Merge to Resize × 4 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -79,8 +79,8 @@ CompoundBlockOperator(Project) × 1 Merge to Resize × 4 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -367,8 +367,8 @@ CompoundBlockOperator(Project) × 1 LimitTransform × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -385,8 +385,8 @@ CompoundBlockOperator(Project) × 1 LimitTransform × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -404,8 +404,8 @@ CompoundBlockOperator(Project) × 1 LimitTransform × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -418,8 +418,8 @@ CompoundBlockOperator(Project) × 1 LimitTransform × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -432,8 +432,8 @@ CompoundBlockOperator(Project) × 1 LimitTransform × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -452,8 +452,8 @@ CompoundBlockOperator(Project) × 1 Merge to Resize × 4 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -471,8 +471,8 @@ CompoundBlockOperator(Project) × 1 Merge to Resize × 4 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 TransformFilter × 1 AddInternalColumnsTransform × 1 DeserializeDataTransform × 1 @@ -535,8 +535,8 @@ CompoundBlockOperator(Project) × 1 TransformFilter × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(WindowTopN) × 1 + ShufflePartition(WindowTopN) × 1 DeserializeDataTransform × 1 SyncReadParquetDataTransform × 1 BlockPartitionSource × 1 @@ -565,8 +565,8 @@ CompoundBlockOperator(Project) × 1 Transform Window × 1 Transform Window × 1 TransformWindowPartitionCollect × 1 - ShuffleMergePartition × 1 - ShufflePartition × 1 + ShuffleMergePartition(Window) × 1 + ShufflePartition(Window) × 1 CompoundBlockOperator(Map) × 1 NumbersSourceTransform × 1 diff --git a/tests/sqllogictests/suites/tpcds/spill.test b/tests/sqllogictests/suites/tpcds/spill.test index 366c85121b6d5..de6c46c8f8cfa 100644 --- a/tests/sqllogictests/suites/tpcds/spill.test +++ b/tests/sqllogictests/suites/tpcds/spill.test @@ -47,25 +47,25 @@ statement ok drop table if exists t; statement ok -set max_block_size = 65536; +unset max_block_size; statement ok -set join_spilling_memory_ratio = 60; +unset join_spilling_memory_ratio; statement ok -set join_spilling_bytes_threshold_per_proc = 0; +unset join_spilling_bytes_threshold_per_proc; statement ok -set join_spilling_buffer_threshold_per_proc_mb = 512; +unset join_spilling_buffer_threshold_per_proc_mb; statement ok -set sort_spilling_memory_ratio = 60; +unset sort_spilling_memory_ratio; statement ok -set sort_spilling_bytes_threshold_per_proc = 0; +unset sort_spilling_bytes_threshold_per_proc; statement ok -set window_partition_spilling_memory_ratio = 60; +unset window_partition_spilling_memory_ratio; statement ok -set window_partition_spilling_bytes_threshold_per_proc = 0; \ No newline at end of file +unset window_partition_spilling_bytes_threshold_per_proc; \ No newline at end of file