diff --git a/src/common/base/src/base/mod.rs b/src/common/base/src/base/mod.rs index 7be5e319cac25..272f2cc1e30d6 100644 --- a/src/common/base/src/base/mod.rs +++ b/src/common/base/src/base/mod.rs @@ -22,6 +22,7 @@ mod singleton_instance; mod stop_handle; mod stoppable; mod string; +mod take_mut; mod uniq_id; pub use net::get_free_tcp_port; @@ -46,6 +47,7 @@ pub use string::mask_connection_info; pub use string::mask_string; pub use string::unescape_for_key; pub use string::unescape_string; +pub use take_mut::take_mut; pub use tokio; pub use uniq_id::GlobalSequence; pub use uniq_id::GlobalUniqName; diff --git a/src/common/base/src/base/take_mut.rs b/src/common/base/src/base/take_mut.rs new file mode 100644 index 0000000000000..6023e7f4e2a3d --- /dev/null +++ b/src/common/base/src/base/take_mut.rs @@ -0,0 +1,38 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::panic; + +use common_exception::Result; + +/// copy from https://docs.rs/take_mut/0.2.2/take_mut/fn.take.html with some modifications. +/// if a panic occurs, the entire process will be aborted, as there's no valid `T` to put back into the `&mut T`. +pub fn take_mut(mut_ref: &mut T, closure: F) -> Result<()> +where F: FnOnce(T) -> Result { + use std::ptr; + + unsafe { + let old_t = ptr::read(mut_ref); + let closure_result = panic::catch_unwind(panic::AssertUnwindSafe(|| closure(old_t))); + + match closure_result { + Ok(Ok(new_t)) => { + ptr::write(mut_ref, new_t); + Ok(()) + } + Ok(Err(e)) => Err(e), + Err(_) => ::std::process::abort(), + } + } +} diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index 80ba95e92f665..cbae779c3e18e 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -313,6 +313,10 @@ pub trait ValueType: Debug + Clone + PartialEq + Sized + 'static { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder>; + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option; + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option; + fn upcast_scalar(scalar: Self::Scalar) -> Scalar; fn upcast_column(col: Self::Column) -> Column; fn upcast_domain(domain: Self::Domain) -> Domain; diff --git a/src/query/expression/src/types/any.rs b/src/query/expression/src/types/any.rs index ae267ebef30ac..bc10f689ede14 100755 --- a/src/query/expression/src/types/any.rs +++ b/src/query/expression/src/types/any.rs @@ -64,6 +64,14 @@ impl ValueType for AnyType { Some(builder) } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + Some(builder) + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(builder) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { scalar } diff --git a/src/query/expression/src/types/array.rs b/src/query/expression/src/types/array.rs index caf7e078a72eb..f80f4d39a7cf4 100755 --- a/src/query/expression/src/types/array.rs +++ b/src/query/expression/src/types/array.rs @@ -81,6 +81,38 @@ impl ValueType for ArrayType { None } + #[allow(clippy::manual_map)] + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Array(inner) => { + let builder = T::try_downcast_owned_builder(inner.builder); + // ``` + // builder.map(|builder| ArrayColumnBuilder { + // builder, + // offsets: inner.offsets, + // }) + // ``` + // If we using the clippy recommend way like above, the compiler will complain: + // use of partially moved value: `inner`. + // That's rust borrow checker error, if we using the new borrow checker named polonius, + // everything goes fine, but polonius is very slow, so we allow manual map here. + if let Some(builder) = builder { + Some(ArrayColumnBuilder { + builder, + offsets: inner.offsets, + }) + } else { + None + } + } + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Array(Box::new(builder.upcast()))) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Array(T::upcast_column(scalar)) } @@ -366,6 +398,13 @@ impl ArrayColumnBuilder { (self.offsets[0] as usize)..(self.offsets[1] as usize), ) } + + pub fn upcast(self) -> ArrayColumnBuilder { + ArrayColumnBuilder { + builder: T::try_upcast_column_builder(self.builder).unwrap(), + offsets: self.offsets, + } + } } impl ArrayColumnBuilder { diff --git a/src/query/expression/src/types/bitmap.rs b/src/query/expression/src/types/bitmap.rs index afe30de74c8bb..5b9788f9955b4 100644 --- a/src/query/expression/src/types/bitmap.rs +++ b/src/query/expression/src/types/bitmap.rs @@ -63,11 +63,22 @@ impl ValueType for BitmapType { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder> { match builder { - crate::ColumnBuilder::Bitmap(builder) => Some(builder), + ColumnBuilder::Bitmap(builder) => Some(builder), _ => None, } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Bitmap(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Bitmap(builder)) + } + fn try_downcast_domain(domain: &Domain) -> Option { if domain.is_undefined() { Some(()) diff --git a/src/query/expression/src/types/boolean.rs b/src/query/expression/src/types/boolean.rs index 51033b2e3d7da..4411ac495527d 100644 --- a/src/query/expression/src/types/boolean.rs +++ b/src/query/expression/src/types/boolean.rs @@ -70,11 +70,22 @@ impl ValueType for BooleanType { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder> { match builder { - crate::ColumnBuilder::Boolean(builder) => Some(builder), + ColumnBuilder::Boolean(builder) => Some(builder), _ => None, } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Boolean(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Boolean(builder)) + } + fn try_downcast_domain(domain: &Domain) -> Option { domain.as_boolean().map(BooleanDomain::clone) } diff --git a/src/query/expression/src/types/date.rs b/src/query/expression/src/types/date.rs index 31927540a63f6..7975a861cda23 100644 --- a/src/query/expression/src/types/date.rs +++ b/src/query/expression/src/types/date.rs @@ -103,6 +103,17 @@ impl ValueType for DateType { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Date(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Date(builder)) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Date(scalar) } diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index d5b4022522cc6..d73c0f89deb84 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -89,6 +89,14 @@ impl ValueType for DecimalType { Num::try_downcast_builder(builder) } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + Num::try_downcast_owned_builder(builder) + } + + fn try_upcast_column_builder(_builder: Self::ColumnBuilder) -> Option { + None + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Num::upcast_scalar(scalar, Num::default_decimal_size()) } @@ -291,6 +299,8 @@ pub trait Decimal: fn try_downcast_column(column: &Column) -> Option<(Buffer, DecimalSize)>; fn try_downcast_builder<'a>(builder: &'a mut ColumnBuilder) -> Option<&'a mut Vec>; + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option>; + fn try_downcast_scalar(scalar: &DecimalScalar) -> Option; fn try_downcast_domain(domain: &DecimalDomain) -> Option>; @@ -477,6 +487,13 @@ impl Decimal for i128 { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option> { + match builder { + ColumnBuilder::Decimal(DecimalColumnBuilder::Decimal128(s, _)) => Some(s), + _ => None, + } + } + fn try_downcast_scalar<'a>(scalar: &DecimalScalar) -> Option { match scalar { DecimalScalar::Decimal128(val, _) => Some(*val), @@ -630,6 +647,13 @@ impl Decimal for i256 { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option> { + match builder { + ColumnBuilder::Decimal(DecimalColumnBuilder::Decimal256(s, _)) => Some(s), + _ => None, + } + } + fn try_downcast_scalar<'a>(scalar: &DecimalScalar) -> Option { match scalar { DecimalScalar::Decimal256(val, _) => Some(*val), diff --git a/src/query/expression/src/types/empty_array.rs b/src/query/expression/src/types/empty_array.rs index 35782d6643b92..7207c5533d4e6 100644 --- a/src/query/expression/src/types/empty_array.rs +++ b/src/query/expression/src/types/empty_array.rs @@ -76,6 +76,17 @@ impl ValueType for EmptyArrayType { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::EmptyArray { len } => Some(len), + _ => None, + } + } + + fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::EmptyArray { len }) + } + fn upcast_scalar(_: Self::Scalar) -> Scalar { Scalar::EmptyArray } diff --git a/src/query/expression/src/types/empty_map.rs b/src/query/expression/src/types/empty_map.rs index e3567884f2c37..2176d221d8fc8 100644 --- a/src/query/expression/src/types/empty_map.rs +++ b/src/query/expression/src/types/empty_map.rs @@ -76,6 +76,17 @@ impl ValueType for EmptyMapType { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::EmptyMap { len } => Some(len), + _ => None, + } + } + + fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::EmptyMap { len }) + } + fn upcast_scalar(_: Self::Scalar) -> Scalar { Scalar::EmptyMap } diff --git a/src/query/expression/src/types/generic.rs b/src/query/expression/src/types/generic.rs index cb356884432ef..b8ce249193994 100755 --- a/src/query/expression/src/types/generic.rs +++ b/src/query/expression/src/types/generic.rs @@ -67,6 +67,14 @@ impl ValueType for GenericType { Some(builder) } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + Some(builder) + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(builder) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { scalar } diff --git a/src/query/expression/src/types/map.rs b/src/query/expression/src/types/map.rs index b03f882f3d43d..38d170e813f50 100755 --- a/src/query/expression/src/types/map.rs +++ b/src/query/expression/src/types/map.rs @@ -87,6 +87,14 @@ impl ValueType for KvPair { None } + fn try_downcast_owned_builder<'a>(_builder: ColumnBuilder) -> Option { + None + } + + fn try_upcast_column_builder(_builder: Self::ColumnBuilder) -> Option { + None + } + fn upcast_scalar((k, v): Self::Scalar) -> Scalar { Scalar::Tuple(vec![K::upcast_scalar(k), V::upcast_scalar(v)]) } @@ -351,6 +359,14 @@ impl ValueType for MapType { as ValueType>::try_downcast_builder(builder) } + fn try_downcast_owned_builder<'a>(builder: ColumnBuilder) -> Option { + as ValueType>::try_downcast_owned_builder(builder) + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + as ValueType>::try_upcast_column_builder(builder) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Map(KvPair::::upcast_column(scalar)) } diff --git a/src/query/expression/src/types/null.rs b/src/query/expression/src/types/null.rs index c49ca12d11e2d..78c5737934c32 100644 --- a/src/query/expression/src/types/null.rs +++ b/src/query/expression/src/types/null.rs @@ -75,11 +75,22 @@ impl ValueType for NullType { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder> { match builder { - crate::ColumnBuilder::Null { len } => Some(len), + ColumnBuilder::Null { len } => Some(len), _ => None, } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Null { len } => Some(len), + _ => None, + } + } + + fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Null { len }) + } + fn upcast_scalar(_: Self::Scalar) -> Scalar { Scalar::Null } diff --git a/src/query/expression/src/types/nullable.rs b/src/query/expression/src/types/nullable.rs index 0b8bfd4bbf12e..78a267052bb7f 100755 --- a/src/query/expression/src/types/nullable.rs +++ b/src/query/expression/src/types/nullable.rs @@ -95,6 +95,38 @@ impl ValueType for NullableType { None } + #[allow(clippy::manual_map)] + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Nullable(inner) => { + let builder = T::try_downcast_owned_builder(inner.builder); + // ``` + // builder.map(|builder| NullableColumnBuilder { + // builder, + // validity: inner.validity, + // }) + // ``` + // If we using the clippy recommend way like above, the compiler will complain: + // use of partially moved value: `inner`. + // That's rust borrow checker error, if we using the new borrow checker named polonius, + // everything goes fine, but polonius is very slow, so we allow manual map here. + if let Some(builder) = builder { + Some(NullableColumnBuilder { + builder, + validity: inner.validity, + }) + } else { + None + } + } + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Nullable(Box::new(builder.upcast()))) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { match scalar { Some(scalar) => T::upcast_scalar(scalar), @@ -346,6 +378,13 @@ impl NullableColumnBuilder { None } } + + pub fn upcast(self) -> NullableColumnBuilder { + NullableColumnBuilder { + builder: T::try_upcast_column_builder(self.builder).unwrap(), + validity: self.validity, + } + } } impl NullableColumnBuilder { diff --git a/src/query/expression/src/types/number.rs b/src/query/expression/src/types/number.rs index ebed746b32556..13bb296e35f5b 100644 --- a/src/query/expression/src/types/number.rs +++ b/src/query/expression/src/types/number.rs @@ -136,6 +136,17 @@ impl ValueType for NumberType { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Number(num) => Num::try_downcast_owned_builder(num), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Num::try_upcast_column_builder(builder) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Number(Num::upcast_scalar(scalar)) } @@ -757,6 +768,10 @@ pub trait Number: fn try_downcast_scalar(scalar: &NumberScalar) -> Option; fn try_downcast_column(col: &NumberColumn) -> Option>; fn try_downcast_builder(col: &mut NumberColumnBuilder) -> Option<&mut Vec>; + + fn try_downcast_owned_builder(col: NumberColumnBuilder) -> Option>; + + fn try_upcast_column_builder(builder: Vec) -> Option; fn try_downcast_domain(domain: &NumberDomain) -> Option>; fn upcast_scalar(scalar: Self) -> NumberScalar; fn upcast_column(col: Buffer) -> NumberColumn; @@ -789,6 +804,17 @@ impl Number for u8 { builder.as_u_int8_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::UInt8(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::UInt8(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_u_int8().cloned() } @@ -828,6 +854,17 @@ impl Number for u16 { builder.as_u_int16_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::UInt16(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::UInt16(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_u_int16().cloned() } @@ -868,6 +905,17 @@ impl Number for u32 { builder.as_u_int32_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::UInt32(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::UInt32(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_u_int32().cloned() } @@ -908,6 +956,17 @@ impl Number for u64 { builder.as_u_int64_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::UInt64(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::UInt64(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_u_int64().cloned() } @@ -948,6 +1007,17 @@ impl Number for i8 { builder.as_int8_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Int8(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Int8(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_int8().cloned() } @@ -988,6 +1058,17 @@ impl Number for i16 { builder.as_int16_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Int16(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Int16(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_int16().cloned() } @@ -1028,6 +1109,17 @@ impl Number for i32 { builder.as_int32_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Int32(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Int32(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_int32().cloned() } @@ -1068,6 +1160,17 @@ impl Number for i64 { builder.as_int64_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Int64(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Int64(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_int64().cloned() } @@ -1108,6 +1211,17 @@ impl Number for F32 { builder.as_float32_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Float32(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Float32(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_float32().cloned() } @@ -1156,6 +1270,17 @@ impl Number for F64 { builder.as_float64_mut() } + fn try_downcast_owned_builder(builder: NumberColumnBuilder) -> Option> { + match builder { + NumberColumnBuilder::Float64(b) => Some(b), + _ => None, + } + } + + fn try_upcast_column_builder(v: Vec) -> Option { + Some(ColumnBuilder::Number(NumberColumnBuilder::Float64(v))) + } + fn try_downcast_domain(domain: &NumberDomain) -> Option> { domain.as_float64().cloned() } diff --git a/src/query/expression/src/types/string.rs b/src/query/expression/src/types/string.rs index b47c4088e3889..ac88684b6a824 100644 --- a/src/query/expression/src/types/string.rs +++ b/src/query/expression/src/types/string.rs @@ -74,11 +74,22 @@ impl ValueType for StringType { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder> { match builder { - crate::ColumnBuilder::String(builder) => Some(builder), + ColumnBuilder::String(builder) => Some(builder), _ => None, } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::String(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::String(builder)) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::String(scalar) } @@ -279,6 +290,7 @@ impl<'a> Iterator for StringIterator<'a> { } unsafe impl<'a> TrustedLen for StringIterator<'a> {} + unsafe impl<'a> std::iter::TrustedLen for StringIterator<'a> {} #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] diff --git a/src/query/expression/src/types/timestamp.rs b/src/query/expression/src/types/timestamp.rs index 0ac003d1019d4..dde8687408e32 100644 --- a/src/query/expression/src/types/timestamp.rs +++ b/src/query/expression/src/types/timestamp.rs @@ -58,6 +58,7 @@ pub fn check_timestamp(micros: i64) -> Result { Err("timestamp is out of range".to_string()) } } + #[derive(Debug, Clone, PartialEq, Eq)] pub struct TimestampType; @@ -109,6 +110,17 @@ impl ValueType for TimestampType { } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Timestamp(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Timestamp(builder)) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Timestamp(scalar) } diff --git a/src/query/expression/src/types/variant.rs b/src/query/expression/src/types/variant.rs index 0f683c659dbfe..99ba6f8a30835 100644 --- a/src/query/expression/src/types/variant.rs +++ b/src/query/expression/src/types/variant.rs @@ -82,11 +82,22 @@ impl ValueType for VariantType { builder: &'a mut ColumnBuilder, ) -> Option<&'a mut Self::ColumnBuilder> { match builder { - crate::ColumnBuilder::Variant(builder) => Some(builder), + ColumnBuilder::Variant(builder) => Some(builder), _ => None, } } + fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Variant(builder) => Some(builder), + _ => None, + } + } + + fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + Some(ColumnBuilder::Variant(builder)) + } + fn upcast_scalar(scalar: Self::Scalar) -> Scalar { Scalar::Variant(scalar) } diff --git a/src/query/functions/src/aggregates/aggregate_avg.rs b/src/query/functions/src/aggregates/aggregate_avg.rs index 8d5318ec4d645..04c7bcfe41af1 100644 --- a/src/query/functions/src/aggregates/aggregate_avg.rs +++ b/src/query/functions/src/aggregates/aggregate_avg.rs @@ -293,8 +293,8 @@ pub fn try_create_aggregate_avg_function( } } _ => Err(ErrorCode::BadDataValueType(format!( - "AggregateSumFunction does not support type '{:?}'", - arguments[0] + "{} does not support type '{:?}'", + display_name, arguments[0] ))), }) } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 104ae081b0fc5..9261b7d776fdb 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -275,8 +275,8 @@ pub fn try_create_aggregate_sum_function( } } _ => Err(ErrorCode::BadDataValueType(format!( - "AggregateSumFunction does not support type '{:?}'", - arguments[0] + "{} does not support type '{:?}'", + display_name, arguments[0] ))), }) } diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 0002d0939a7d4..306dcd082ca14 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -14,13 +14,18 @@ use std::alloc::Layout; use std::any::Any; +use std::any::TypeId; use std::fmt::Display; use std::fmt::Formatter; use std::marker::PhantomData; use std::sync::Arc; use common_arrow::arrow::bitmap::Bitmap; +use common_base::base::take_mut; use common_exception::Result; +use common_expression::types::decimal::Decimal128Type; +use common_expression::types::decimal::Decimal256Type; +use common_expression::types::decimal::DecimalColumnBuilder; use common_expression::types::DataType; use common_expression::types::ValueType; use common_expression::AggregateFunction; @@ -29,7 +34,6 @@ use common_expression::Column; use common_expression::ColumnBuilder; use common_expression::Scalar; use common_expression::StateAddr; - pub trait UnaryState: Send + Sync + Default where T: ValueType, @@ -129,6 +133,33 @@ where self.need_drop = need_drop; self } + + fn do_merge_result(&self, state: &mut S, builder: &mut ColumnBuilder) -> Result<()> { + match builder { + // current decimal implementation hard do upcast_builder, we do downcast manually. + ColumnBuilder::Decimal(b) => match b { + DecimalColumnBuilder::Decimal128(_, _) => { + debug_assert!(TypeId::of::() == TypeId::of::()); + let builder = R::try_downcast_builder(builder).unwrap(); + state.merge_result(builder, self.function_data.as_deref()) + } + DecimalColumnBuilder::Decimal256(_, _) => { + debug_assert!(TypeId::of::() == TypeId::of::()); + let builder = R::try_downcast_builder(builder).unwrap(); + state.merge_result(builder, self.function_data.as_deref()) + } + }, + // some `ValueType` like `NullableType` need ownership to downcast builder, + // so here we using an unsafe way to take the ownership of builder. + // See [`take_mut`] for details. + _ => take_mut(builder, |builder| { + let mut builder = R::try_downcast_owned_builder(builder).unwrap(); + state + .merge_result(&mut builder, self.function_data.as_deref()) + .map(|_| R::try_upcast_column_builder(builder).unwrap()) + }), + } + } } impl AggregateFunction for AggregateUnaryFunction @@ -226,9 +257,7 @@ where fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { let state: &mut S = place.get::(); - let builder = R::try_downcast_builder(builder).unwrap(); - state.merge_result(builder, self.function_data.as_deref())?; - Ok(()) + self.do_merge_result(state, builder) } fn batch_merge_result( @@ -239,8 +268,7 @@ where ) -> Result<()> { for place in places { let state: &mut S = place.next(offset).get::(); - let builder = R::try_downcast_builder(builder).unwrap(); - state.merge_result(builder, self.function_data.as_deref())?; + self.do_merge_result(state, builder)?; } Ok(()) }