diff --git a/src/query/expression/src/evaluator.rs b/src/query/expression/src/evaluator.rs index 82145643d517a..37a7c06a7df5f 100644 --- a/src/query/expression/src/evaluator.rs +++ b/src/query/expression/src/evaluator.rs @@ -38,6 +38,7 @@ use crate::types::nullable::NullableDomain; use crate::types::BooleanType; use crate::types::DataType; use crate::types::NullableType; +use crate::types::NumberScalar; use crate::values::Column; use crate::values::ColumnBuilder; use crate::values::Scalar; @@ -750,7 +751,10 @@ impl<'a> Evaluator<'a> { }; let params = if let DataType::Decimal(ty) = dest_type.remove_nullable() { - vec![ty.precision() as usize, ty.scale() as usize] + vec![ + Scalar::Number(NumberScalar::Int64(ty.precision() as _)), + Scalar::Number(NumberScalar::Int64(ty.scale() as _)), + ] } else { vec![] }; @@ -1618,7 +1622,10 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> { }; let params = if let DataType::Decimal(ty) = dest_type { - vec![ty.precision() as usize, ty.scale() as usize] + vec![ + Scalar::Number(NumberScalar::Int64(ty.precision() as _)), + Scalar::Number(NumberScalar::Int64(ty.scale() as _)), + ] } else { vec![] }; diff --git a/src/query/expression/src/expression.rs b/src/query/expression/src/expression.rs index 9cab633d49c17..1a32ae332490e 100644 --- a/src/query/expression/src/expression.rs +++ b/src/query/expression/src/expression.rs @@ -60,7 +60,7 @@ pub enum RawExpr { FunctionCall { span: Span, name: String, - params: Vec, + params: Vec, args: Vec>, }, LambdaFunctionCall { diff --git a/src/query/expression/src/function.rs b/src/query/expression/src/function.rs index cc9aa9b4ca13b..592b742cae517 100755 --- a/src/query/expression/src/function.rs +++ b/src/query/expression/src/function.rs @@ -49,7 +49,7 @@ pub type AutoCastRules<'a> = &'a [(DataType, DataType)]; /// /// The first argument is the const parameters and the second argument is the types of arguments. pub trait FunctionFactory = - Fn(&[usize], &[DataType]) -> Option> + Send + Sync + 'static; + Fn(&[Scalar], &[DataType]) -> Option> + Send + Sync + 'static; pub struct Function { pub signature: FunctionSignature, @@ -131,7 +131,7 @@ pub enum FunctionID { Factory { name: String, id: usize, - params: Vec, + params: Vec, args_type: Vec, }, } @@ -327,7 +327,7 @@ impl FunctionRegistry { pub fn search_candidates( &self, name: &str, - params: &[usize], + params: &[Scalar], args: &[Expr], ) -> Vec<(FunctionID, Arc)> { let name = name.to_lowercase(); @@ -522,7 +522,7 @@ impl FunctionID { } } - pub fn params(&self) -> &[usize] { + pub fn params(&self) -> &[Scalar] { match self { FunctionID::Builtin { .. } => &[], FunctionID::Factory { params, .. } => params.as_slice(), @@ -558,7 +558,7 @@ impl<'a> EvalContext<'a> { pub fn render_error( &self, span: Span, - params: &[usize], + params: &[Scalar], args: &[Value], func_name: &str, ) -> Result<()> { diff --git a/src/query/expression/src/type_check.rs b/src/query/expression/src/type_check.rs index d1b1f887cb927..9fdcfc4340223 100755 --- a/src/query/expression/src/type_check.rs +++ b/src/query/expression/src/type_check.rs @@ -143,13 +143,12 @@ pub fn check( } else { new_args.push(Expr::Constant { span: None, - scalar: Scalar::Number(NumberScalar::Int64(scale)), + scalar: Scalar::Number(scale.into()), data_type: Int64Type::data_type(), }) } scale = scale.clamp(-76, 76); - let add_on_scale = (scale + 76) as usize; - let params = vec![add_on_scale]; + let params = vec![Scalar::Number(scale.into())]; return check_function(*span, name, ¶ms, &args_expr, fn_registry); } @@ -206,7 +205,10 @@ pub fn check_cast( // fast path to eval function for cast if let Some(cast_fn) = get_simple_cast_function(is_try, dest_type) { let params = if let DataType::Decimal(ty) = dest_type { - vec![ty.precision() as usize, ty.scale() as usize] + vec![ + Scalar::Number(NumberScalar::Int64(ty.precision() as _)), + Scalar::Number(NumberScalar::Int64(ty.scale() as _)), + ] } else { vec![] }; @@ -286,7 +288,7 @@ pub fn check_number( pub fn check_function( span: Span, name: &str, - params: &[usize], + params: &[Scalar], args: &[Expr], fn_registry: &FunctionRegistry, ) -> Result> { diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index 91233032fe169..4c5b19697fb3a 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -21,6 +21,8 @@ use borsh::BorshSerialize; use databend_common_arrow::arrow::buffer::Buffer; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_io::display_decimal_128; +use databend_common_io::display_decimal_256; use enum_as_inner::EnumAsInner; use ethnum::i256; use ethnum::AsI256; @@ -307,6 +309,7 @@ pub trait Decimal: fn from_i128>(value: U) -> Self; fn de_binary(bytes: &mut &[u8]) -> Self; + fn display(self, scale: u8) -> String; fn to_float32(self, scale: u8) -> f32; fn to_float64(self, scale: u8) -> f64; @@ -453,6 +456,10 @@ impl Decimal for i128 { i128::from_le_bytes(bs) } + fn display(self, scale: u8) -> String { + display_decimal_128(self, scale) + } + fn to_float32(self, scale: u8) -> f32 { let div = 10_f32.powi(scale as i32); self as f32 / div @@ -618,6 +625,10 @@ impl Decimal for i256 { i256::from_le_bytes(bs) } + fn display(self, scale: u8) -> String { + display_decimal_256(self, scale) + } + fn to_float32(self, scale: u8) -> f32 { let div = 10_f32.powi(scale as i32); self.as_f32() / div diff --git a/src/query/expression/src/types/number.rs b/src/query/expression/src/types/number.rs index 58fe6a919dcad..93f8902c22712 100644 --- a/src/query/expression/src/types/number.rs +++ b/src/query/expression/src/types/number.rs @@ -502,6 +502,14 @@ impl NumberScalar { } } +impl From for NumberScalar +where T: Number +{ + fn from(value: T) -> Self { + T::upcast_scalar(value) + } +} + impl NumberColumn { pub fn len(&self) -> usize { crate::with_number_type!(|NUM_TYPE| match self { diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs index 2a43d27c67674..aedc7aaa84781 100755 --- a/src/query/expression/src/values.rs +++ b/src/query/expression/src/values.rs @@ -408,6 +408,23 @@ impl Scalar { _ => unreachable!("is_positive() called on non-numeric scalar"), } } + + pub fn get_i64(&self) -> Option { + match self { + Scalar::Number(n) => match n { + NumberScalar::Int8(x) => Some(*x as _), + NumberScalar::Int16(x) => Some(*x as _), + NumberScalar::Int32(x) => Some(*x as _), + NumberScalar::Int64(x) => Some(*x as _), + NumberScalar::UInt8(x) => Some(*x as _), + NumberScalar::UInt16(x) => Some(*x as _), + NumberScalar::UInt32(x) => Some(*x as _), + NumberScalar::UInt64(x) => i64::try_from(*x).ok(), + _ => None, + }, + _ => None, + } + } } impl<'a> ScalarRef<'a> { diff --git a/src/query/functions/src/scalars/arithmetic.rs b/src/query/functions/src/scalars/arithmetic.rs index 1319dacc70e69..3958ddd1a1929 100644 --- a/src/query/functions/src/scalars/arithmetic.rs +++ b/src/query/functions/src/scalars/arithmetic.rs @@ -20,9 +20,8 @@ use std::ops::BitXor; use std::sync::Arc; use databend_common_arrow::arrow::bitmap::Bitmap; -use databend_common_expression::types::decimal::Decimal; -use databend_common_expression::types::decimal::DecimalColumn; use databend_common_expression::types::decimal::DecimalDomain; +use databend_common_expression::types::decimal::DecimalType; use databend_common_expression::types::nullable::NullableColumn; use databend_common_expression::types::nullable::NullableDomain; use databend_common_expression::types::number::Number; @@ -30,7 +29,6 @@ use databend_common_expression::types::number::NumberType; use databend_common_expression::types::number::F64; use databend_common_expression::types::string::StringColumnBuilder; use databend_common_expression::types::AnyType; -use databend_common_expression::types::ArgType; use databend_common_expression::types::DataType; use databend_common_expression::types::DecimalDataType; use databend_common_expression::types::NullableType; @@ -51,13 +49,12 @@ use databend_common_expression::values::ValueRef; use databend_common_expression::vectorize_1_arg; use databend_common_expression::vectorize_with_builder_1_arg; use databend_common_expression::vectorize_with_builder_2_arg; +use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_float_mapped_type; use databend_common_expression::with_integer_mapped_type; use databend_common_expression::with_number_mapped_type; use databend_common_expression::with_number_mapped_type_without_64; use databend_common_expression::with_unsigned_integer_mapped_type; -use databend_common_expression::Column; -use databend_common_expression::ColumnBuilder; use databend_common_expression::Domain; use databend_common_expression::EvalContext; use databend_common_expression::Function; @@ -65,9 +62,6 @@ use databend_common_expression::FunctionDomain; use databend_common_expression::FunctionEval; use databend_common_expression::FunctionRegistry; use databend_common_expression::FunctionSignature; -use databend_common_expression::Scalar; -use databend_common_io::display_decimal_128; -use databend_common_io::display_decimal_256; use ethnum::i256; use lexical_core::FormattedSize; use num_traits::AsPrimitive; @@ -76,6 +70,7 @@ use super::arithmetic_modulo::vectorize_modulo; use super::decimal::register_decimal_to_int; use crate::scalars::decimal::register_decimal_arithmetic; use crate::scalars::decimal::register_decimal_to_float; +use crate::scalars::decimal::register_decimal_to_string; pub fn register(registry: &mut FunctionRegistry) { registry.register_aliases("plus", &["add"]); @@ -775,7 +770,7 @@ pub fn register_decimal_minus(registry: &mut FunctionRegistry) { } _ => unreachable!(), }), - eval: Box::new(move |args, _tx| unary_minus_decimal(args, arg_type.clone())), + eval: Box::new(move |args, ctx| unary_minus_decimal(args, arg_type.clone(), ctx)), }, }; @@ -787,34 +782,20 @@ pub fn register_decimal_minus(registry: &mut FunctionRegistry) { }); } -fn unary_minus_decimal(args: &[ValueRef], arg_type: DataType) -> Value { +fn unary_minus_decimal( + args: &[ValueRef], + arg_type: DataType, + ctx: &mut EvalContext, +) -> Value { let arg = &args[0]; - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &arg_type); - builder.build() - } - }; - - let result = match column { - Column::Decimal(DecimalColumn::Decimal128(buf, size)) => { - DecimalColumn::Decimal128(buf.into_iter().map(|x| -x).collect(), size) - } - Column::Decimal(DecimalColumn::Decimal256(buf, size)) => { - DecimalColumn::Decimal256(buf.into_iter().map(|x| -x).collect(), size) + let arg_type = arg_type.as_decimal().unwrap(); + with_decimal_mapped_type!(|DECIMAL_TYPE| match arg_type { + DecimalDataType::DECIMAL_TYPE(size) => { + type Type = DecimalType; + let arg = arg.try_downcast().unwrap(); + vectorize_1_arg::(|t, _| -t)(arg, ctx).upcast_decimal(*size) } - _ => unreachable!(), - }; - - if is_scalar { - let scalar = result.index(0).unwrap(); - Value::Scalar(Scalar::Decimal(scalar)) - } else { - Value::Column(Column::Decimal(result)) - } + }) } fn register_string_to_number(registry: &mut FunctionRegistry) { @@ -953,76 +934,3 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) { }); } } - -fn register_decimal_to_string(registry: &mut FunctionRegistry) { - // decimal to string - registry.register_function_factory("to_string", |_params, args_type| { - if args_type.len() != 1 { - return None; - } - - let arg_type = args_type[0].remove_nullable(); - if !arg_type.is_decimal() { - return None; - } - - Some(Arc::new(Function { - signature: FunctionSignature { - name: "to_string".to_string(), - args_type: vec![arg_type.clone()], - return_type: StringType::data_type(), - }, - eval: FunctionEval::Scalar { - calc_domain: Box::new(|_, _| FunctionDomain::Full), - eval: Box::new(move |args, tx| decimal_to_string(args, arg_type.clone(), tx)), - }, - })) - }); -} - -fn decimal_to_string( - args: &[ValueRef], - from_type: DataType, - _ctx: &mut EvalContext, -) -> Value { - let arg = &args[0]; - - let mut is_scalar = false; - let column = match arg { - ValueRef::Column(column) => column.clone(), - ValueRef::Scalar(s) => { - is_scalar = true; - let builder = ColumnBuilder::repeat(s, 1, &from_type); - builder.build() - } - }; - - let from_type = from_type.as_decimal().unwrap(); - - let column = match from_type { - DecimalDataType::Decimal128(_) => { - let (buffer, from_size) = i128::try_downcast_column(&column).unwrap(); - let mut builder = StringColumnBuilder::with_capacity(buffer.len(), buffer.len() * 10); - for x in buffer { - builder.put_str(&display_decimal_128(x, from_size.scale)); - builder.commit_row(); - } - builder - } - DecimalDataType::Decimal256(_) => { - let (buffer, from_size) = i256::try_downcast_column(&column).unwrap(); - let mut builder = StringColumnBuilder::with_capacity(buffer.len(), buffer.len() * 10); - for x in buffer { - builder.put_str(&display_decimal_256(x, from_size.scale)); - builder.commit_row(); - } - builder - } - }; - - if is_scalar { - Value::Scalar(Scalar::String(column.build_scalar())) - } else { - Value::Column(Column::String(column.build())) - } -} diff --git a/src/query/functions/src/scalars/decimal/cast.rs b/src/query/functions/src/scalars/decimal/cast.rs index 58b3d9d83d605..e1c4e9b3168a2 100644 --- a/src/query/functions/src/scalars/decimal/cast.rs +++ b/src/query/functions/src/scalars/decimal/cast.rs @@ -17,12 +17,14 @@ use std::sync::Arc; use databend_common_expression::serialize::read_decimal_with_size; use databend_common_expression::types::decimal::*; +use databend_common_expression::types::string::StringColumnBuilder; use databend_common_expression::types::*; use databend_common_expression::vectorize_1_arg; use databend_common_expression::vectorize_with_builder_1_arg; use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_integer_mapped_type; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::Column; use databend_common_expression::Domain; use databend_common_expression::EvalContext; use databend_common_expression::FromData; @@ -32,6 +34,7 @@ use databend_common_expression::FunctionDomain; use databend_common_expression::FunctionEval; use databend_common_expression::FunctionRegistry; use databend_common_expression::FunctionSignature; +use databend_common_expression::Scalar; use databend_common_expression::Value; use databend_common_expression::ValueRef; use ethnum::i256; @@ -40,7 +43,7 @@ use ordered_float::OrderedFloat; // int float to decimal pub fn register_to_decimal(registry: &mut FunctionRegistry) { - let factory = |params: &[usize], args_type: &[DataType]| { + let factory = |params: &[Scalar], args_type: &[DataType]| { if args_type.len() != 1 { return None; } @@ -58,8 +61,8 @@ pub fn register_to_decimal(registry: &mut FunctionRegistry) { } let decimal_size = DecimalSize { - precision: params[0] as u8, - scale: params[1] as u8, + precision: params[0].get_i64()? as _, + scale: params[1].get_i64()? as _, }; let decimal_type = DecimalDataType::from_size(decimal_size).ok()?; @@ -108,7 +111,7 @@ pub(crate) fn register_decimal_to_float(registry: &mut FunctionRegist let is_f32 = matches!(data_type, DataType::Number(NumberDataType::Float32)); - let factory = |_params: &[usize], args_type: &[DataType], data_type: DataType| { + let factory = |_params: &[Scalar], args_type: &[DataType], data_type: DataType| { if args_type.len() != 1 { return None; } @@ -203,7 +206,7 @@ pub(crate) fn register_decimal_to_int(registry: &mut FunctionRegistry let name = format!("to_{}", T::data_type().to_string().to_lowercase()); let try_name = format!("try_to_{}", T::data_type().to_string().to_lowercase()); - let factory = |_params: &[usize], args_type: &[DataType]| { + let factory = |_params: &[Scalar], args_type: &[DataType]| { if args_type.len() != 1 { return None; } @@ -263,6 +266,68 @@ pub(crate) fn register_decimal_to_int(registry: &mut FunctionRegistry }); } +pub(crate) fn register_decimal_to_string(registry: &mut FunctionRegistry) { + // decimal to string + let factory = |_params: &[Scalar], args_type: &[DataType]| { + if args_type.len() != 1 { + return None; + } + + let arg_type = args_type[0].remove_nullable(); + if !arg_type.is_decimal() { + return None; + } + + let function = Function { + signature: FunctionSignature { + name: "to_string".to_string(), + args_type: vec![arg_type.clone()], + return_type: StringType::data_type(), + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(|_, _| FunctionDomain::Full), + eval: Box::new(move |args, tx| decimal_to_string(args, arg_type.clone(), tx)), + }, + }; + + if args_type[0].is_nullable() { + Some(Arc::new(function.passthrough_nullable())) + } else { + Some(Arc::new(function)) + } + }; + registry.register_function_factory("to_string", factory); +} + +fn decimal_to_string( + args: &[ValueRef], + from_type: DataType, + _ctx: &mut EvalContext, +) -> Value { + let arg = &args[0]; + let from_type = from_type.as_decimal().unwrap(); + + with_decimal_mapped_type!(|DECIMAL_TYPE| match from_type { + DecimalDataType::DECIMAL_TYPE(from_size) => { + let arg: ValueRef> = arg.try_downcast().unwrap(); + + match arg { + ValueRef::Column(col) => { + let mut builder = StringColumnBuilder::with_capacity(col.len(), col.len() * 10); + for x in DecimalType::::iter_column(&col) { + builder.put_str(&DECIMAL_TYPE::display(x, from_size.scale)); + builder.commit_row(); + } + Value::Column(Column::String(builder.build())) + } + ValueRef::Scalar(x) => Value::Scalar(Scalar::String( + DECIMAL_TYPE::display(x, from_size.scale).into(), + )), + } + } + }) +} + fn convert_to_decimal( arg: &ValueRef, ctx: &mut EvalContext, diff --git a/src/query/functions/src/scalars/decimal/math.rs b/src/query/functions/src/scalars/decimal/math.rs index 7ed5976b858b8..886785dfcc5e6 100644 --- a/src/query/functions/src/scalars/decimal/math.rs +++ b/src/query/functions/src/scalars/decimal/math.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp::Ord; use std::ops::*; use std::sync::Arc; @@ -26,12 +25,13 @@ use databend_common_expression::FunctionDomain; use databend_common_expression::FunctionEval; use databend_common_expression::FunctionRegistry; use databend_common_expression::FunctionSignature; +use databend_common_expression::Scalar; use databend_common_expression::Value; use databend_common_expression::ValueRef; use ethnum::i256; pub fn register_decimal_math(registry: &mut FunctionRegistry) { - let factory = |params: &[usize], args_type: &[DataType], round_mode: RoundMode| { + let factory = |params: &[Scalar], args_type: &[DataType], round_mode: RoundMode| { if args_type.is_empty() { return None; } @@ -46,7 +46,7 @@ pub fn register_decimal_math(registry: &mut FunctionRegistry) { let scale = if params.is_empty() { 0 } else { - params[0] as i64 - 76 + params[0].get_i64()? }; let decimal_size = DecimalSize { diff --git a/src/query/functions/src/scalars/decimal/mod.rs b/src/query/functions/src/scalars/decimal/mod.rs index 12407ca97278b..c55ea2d36f56b 100644 --- a/src/query/functions/src/scalars/decimal/mod.rs +++ b/src/query/functions/src/scalars/decimal/mod.rs @@ -20,6 +20,7 @@ mod math; pub(crate) use arithmetic::register_decimal_arithmetic; pub(crate) use cast::register_decimal_to_float; pub(crate) use cast::register_decimal_to_int; +pub(crate) use cast::register_decimal_to_string; pub(crate) use cast::register_to_decimal; pub(crate) use comparison::register_decimal_compare_op; pub(crate) use math::register_decimal_math; diff --git a/src/query/functions/src/scalars/other.rs b/src/query/functions/src/scalars/other.rs index 85287d5070a86..46c02e2ca2fa0 100644 --- a/src/query/functions/src/scalars/other.rs +++ b/src/query/functions/src/scalars/other.rs @@ -349,7 +349,8 @@ fn register_grouping(registry: &mut FunctionRegistry) { return None; } - let params = params.to_vec(); + let params: Vec = params.iter().map(|p| p.get_i64().unwrap() as _).collect(); + Some(Arc::new(Function { signature: FunctionSignature { name: "grouping".to_string(), diff --git a/src/query/functions/src/scalars/tuple.rs b/src/query/functions/src/scalars/tuple.rs index c3ae24e0eb550..00aa6df387db7 100644 --- a/src/query/functions/src/scalars/tuple.rs +++ b/src/query/functions/src/scalars/tuple.rs @@ -81,7 +81,8 @@ pub fn register(registry: &mut FunctionRegistry) { registry.register_function_factory("get", |params, args_type| { // Tuple index starts from 1 - let idx = params.first()?.checked_sub(1)?; + let idx = (params.first()?.get_i64()? as usize).checked_sub(1)?; + let fields_ty = match args_type.first()? { DataType::Tuple(tys) => tys, _ => return None, @@ -117,7 +118,7 @@ pub fn register(registry: &mut FunctionRegistry) { registry.register_function_factory("get", |params, args_type| { // Tuple index starts from 1 - let idx = params.first()?.checked_sub(1)?; + let idx = usize::try_from(params.first()?.get_i64()? - 1).ok()?; let fields_ty = match args_type.first()? { DataType::Nullable(box DataType::Tuple(tys)) => tys, _ => return None, @@ -179,7 +180,7 @@ pub fn register(registry: &mut FunctionRegistry) { registry.register_function_factory("get", |params, args_type| { // Tuple index starts from 1 - let idx = params.first()?.checked_sub(1)?; + let idx = usize::try_from(params.first()?.get_i64()? - 1).ok()?; let fields_ty = match args_type.first()? { DataType::Nullable(box DataType::Tuple(tys)) => tys, _ => return None, diff --git a/src/query/functions/src/srfs/variant.rs b/src/query/functions/src/srfs/variant.rs index 606616f917610..6bd1a1cf25eaf 100644 --- a/src/query/functions/src/srfs/variant.rs +++ b/src/query/functions/src/srfs/variant.rs @@ -269,7 +269,7 @@ pub fn register(registry: &mut FunctionRegistry) { { return None; } - let params = params.to_vec(); + let params: Vec = params.iter().map(|x| x.get_i64().unwrap()).collect(); Some(Arc::new(Function { signature: FunctionSignature { @@ -707,7 +707,7 @@ impl FlattenGenerator { } } - fn generate(&mut self, seq: u64, input: &[u8], path: &str, params: &[usize]) -> Vec { + fn generate(&mut self, seq: u64, input: &[u8], path: &str, params: &[i64]) -> Vec { // Only columns required by parent plan need a builder. let mut key_builder = if params.is_empty() || params.contains(&2) { Some(NullableColumnBuilder::::with_capacity(0, &[])) diff --git a/src/query/functions/tests/it/aggregates/mod.rs b/src/query/functions/tests/it/aggregates/mod.rs index 37f80a146defd..b25935554969e 100644 --- a/src/query/functions/tests/it/aggregates/mod.rs +++ b/src/query/functions/tests/it/aggregates/mod.rs @@ -21,7 +21,6 @@ use bumpalo::Bump; use comfy_table::Table; use databend_common_exception::Result; use databend_common_expression::type_check; -use databend_common_expression::types::number::NumberScalar; use databend_common_expression::types::AnyType; use databend_common_expression::types::DataType; use databend_common_expression::BlockEntry; @@ -85,11 +84,6 @@ pub fn run_agg_ast( .collect::>() .unwrap(); - let params = params - .iter() - .map(|p| Scalar::Number(NumberScalar::UInt64(*p as u64))) - .collect(); - // Convert the delimiter of string_agg to params let params = if name.eq_ignore_ascii_case("string_agg") && args.len() == 2 { let val = args[1].0.as_scalar().unwrap(); diff --git a/src/query/functions/tests/it/scalars/parser.rs b/src/query/functions/tests/it/scalars/parser.rs index 68fdf0dab3739..3f5d0cd7910b9 100644 --- a/src/query/functions/tests/it/scalars/parser.rs +++ b/src/query/functions/tests/it/scalars/parser.rs @@ -149,8 +149,8 @@ pub fn transform_expr(ast: AExpr, columns: &[(&str, DataType)]) -> RawExpr { params: params .into_iter() .map(|param| match param { - ASTLiteral::UInt64(u) => u as usize, - ASTLiteral::Decimal256 { .. } => 0_usize, + ASTLiteral::UInt64(u) => Scalar::Number((u as i64).into()), + ASTLiteral::Decimal256 { .. } => Scalar::Number(0i64.into()), _ => unimplemented!(), }) .collect(), @@ -349,9 +349,13 @@ pub fn transform_expr(ast: AExpr, columns: &[(&str, DataType)]) -> RawExpr { }, ]), MapAccessor::DotNumber { key } => { - (vec![key as usize], vec![transform_expr(*expr, columns)]) + (vec![key as i64], vec![transform_expr(*expr, columns)]) } }; + let params = params + .into_iter() + .map(|x| Scalar::Number(x.into())) + .collect(); RawExpr::FunctionCall { span, name: "get".to_string(), diff --git a/src/query/functions/tests/it/scalars/testdata/math.txt b/src/query/functions/tests/it/scalars/testdata/math.txt index d959a713f93c6..90b313cefd105 100644 --- a/src/query/functions/tests/it/scalars/testdata/math.txt +++ b/src/query/functions/tests/it/scalars/testdata/math.txt @@ -281,7 +281,7 @@ evaluation (internal): ast : round(-1.23) raw expr : round(minus(1.23)) -checked expr : round(76)(minus(1.23_d128(3,2))) +checked expr : round(0)(minus(1.23_d128(3,2))) optimized expr : -1_d128(3,0) output type : Decimal(3, 0) output domain : {-1..=-1} @@ -290,7 +290,7 @@ output : -1 ast : round(1.298, 1) raw expr : round(1.298, 1) -checked expr : round(77)(1.298_d128(4,3), 1_u8) +checked expr : round(1)(1.298_d128(4,3), 1_u8) optimized expr : 1.3_d128(4,1) output type : Decimal(4, 1) output domain : {1.3..=1.3} @@ -299,7 +299,7 @@ output : 1.3 ast : round(1.298, 0) raw expr : round(1.298, 0) -checked expr : round(76)(1.298_d128(4,3), 0_u8) +checked expr : round(0)(1.298_d128(4,3), 0_u8) optimized expr : 1_d128(4,0) output type : Decimal(4, 0) output domain : {1..=1} @@ -308,7 +308,7 @@ output : 1 ast : round(23.298, -1) raw expr : round(23.298, minus(1)) -checked expr : round(75)(23.298_d128(5,3), minus(1_u8)) +checked expr : round(-1)(23.298_d128(5,3), minus(1_u8)) optimized expr : 20_d128(5,0) output type : Decimal(5, 0) output domain : {20..=20} @@ -317,7 +317,7 @@ output : 20 ast : round(0.12345678901234567890123456789012345, 35) raw expr : round(0.12345678901234567890123456789012345, 35) -checked expr : round(111)(0.12345678901234567890123456789012345_d128(35,35), 35_u8) +checked expr : round(35)(0.12345678901234567890123456789012345_d128(35,35), 35_u8) optimized expr : 0.12345678901234567890123456789012345_d128(35,35) output type : Decimal(35, 35) output domain : {0.12345678901234567890123456789012345..=0.12345678901234567890123456789012345} @@ -410,7 +410,7 @@ evaluation (internal): ast : truncate(1.223, 1) raw expr : truncate(1.223, 1) -checked expr : truncate(77)(1.223_d128(4,3), 1_u8) +checked expr : truncate(1)(1.223_d128(4,3), 1_u8) optimized expr : 1.2_d128(4,1) output type : Decimal(4, 1) output domain : {1.2..=1.2} @@ -419,7 +419,7 @@ output : 1.2 ast : truncate(1.999) raw expr : truncate(1.999) -checked expr : truncate(76)(1.999_d128(4,3)) +checked expr : truncate(0)(1.999_d128(4,3)) optimized expr : 1_d128(4,0) output type : Decimal(4, 0) output domain : {1..=1} @@ -428,7 +428,7 @@ output : 1 ast : truncate(1.999, 1) raw expr : truncate(1.999, 1) -checked expr : truncate(77)(1.999_d128(4,3), 1_u8) +checked expr : truncate(1)(1.999_d128(4,3), 1_u8) optimized expr : 1.9_d128(4,1) output type : Decimal(4, 1) output domain : {1.9..=1.9} @@ -446,7 +446,7 @@ output : 100 ast : truncate(10.28*100, 0) raw expr : truncate(multiply(10.28, 100), 0) -checked expr : truncate(76)(multiply(to_decimal(7, 2)(10.28_d128(4,2)), to_decimal(7, 0)(100_u8)), 0_u8) +checked expr : truncate(0)(multiply(to_decimal(7, 2)(10.28_d128(4,2)), to_decimal(7, 0)(100_u8)), 0_u8) optimized expr : 1028_d128(7,0) output type : Decimal(7, 0) output domain : {1028..=1028} diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 4bb10ad3c6969..d69587387ddce 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -282,7 +282,7 @@ impl<'a> AggregateRewriter<'a> { let mut replaced_params = Vec::with_capacity(function.arguments.len()); for arg in &function.arguments { if let Some(index) = agg_info.group_items_map.get(arg) { - replaced_params.push(*index); + replaced_params.push(*index as _); } else { return Err(ErrorCode::BadArguments( "Arguments of grouping should be group by expressions", diff --git a/src/query/sql/src/planner/binder/table.rs b/src/query/sql/src/planner/binder/table.rs index bc56357f61d13..6a3646c578799 100644 --- a/src/query/sql/src/planner/binder/table.rs +++ b/src/query/sql/src/planner/binder/table.rs @@ -378,7 +378,7 @@ impl Binder { let field_expr = ScalarExpr::FunctionCall(FunctionCall { span: *span, func_name: "get".to_string(), - params: vec![i + 1], + params: vec![(i + 1) as i64], arguments: vec![scalar.clone()], }); let data_type = field_expr.data_type()?; diff --git a/src/query/sql/src/planner/optimizer/rule/utils/constant.rs b/src/query/sql/src/planner/optimizer/rule/utils/constant.rs index 2fdad26bac068..16d123e04b8b2 100644 --- a/src/query/sql/src/planner/optimizer/rule/utils/constant.rs +++ b/src/query/sql/src/planner/optimizer/rule/utils/constant.rs @@ -204,7 +204,7 @@ pub fn remove_trivial_type_cast(left: ScalarExpr, right: ScalarExpr) -> (ScalarE (**argument).clone(), ScalarExpr::ConstantExpr(ConstantExpr { span: *span, - value: Scalar::Number(NumberScalar::Int64(v)), + value: Scalar::Number(v.into()), }), ); } diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index 3c7ed84010c63..d8e05ea9c6c52 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -526,7 +526,7 @@ pub struct FunctionCall { #[educe(Hash(ignore), PartialEq(ignore), Eq(ignore))] pub span: Span, pub func_name: String, - pub params: Vec, + pub params: Vec, pub arguments: Vec, } diff --git a/src/query/sql/src/planner/semantic/lowering.rs b/src/query/sql/src/planner/semantic/lowering.rs index dfaad7b135128..22c9162dea10e 100644 --- a/src/query/sql/src/planner/semantic/lowering.rs +++ b/src/query/sql/src/planner/semantic/lowering.rs @@ -22,6 +22,7 @@ use databend_common_expression::ColumnIndex; use databend_common_expression::DataSchema; use databend_common_expression::Expr; use databend_common_expression::RawExpr; +use databend_common_expression::Scalar; use databend_common_functions::BUILTIN_FUNCTIONS; use crate::binder::ColumnBindingBuilder; @@ -229,7 +230,11 @@ impl ScalarExpr { ScalarExpr::FunctionCall(func) => RawExpr::FunctionCall { span: func.span, name: func.func_name.clone(), - params: func.params.clone(), + params: func + .params + .iter() + .map(|x| Scalar::Number((*x).into())) + .collect(), args: func.arguments.iter().map(ScalarExpr::as_raw_expr).collect(), }, ScalarExpr::CastExpr(cast) => RawExpr::Cast { diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index e47550676a060..7b55453bff26e 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -861,7 +861,7 @@ impl<'a> TypeChecker<'a> { let params = params .iter() .map(|literal| match literal { - Literal::UInt64(n) => Ok(*n as usize), + Literal::UInt64(n) => Ok(*n as i64), lit => Err(ErrorCode::SemanticError(format!( "invalid parameter {lit} for scalar function" )) @@ -1054,7 +1054,7 @@ impl<'a> TypeChecker<'a> { if let databend_common_expression::Scalar::Number(NumberScalar::UInt8(0)) = expr.value { args[1] = ConstantExpr { span: expr.span, - value: databend_common_expression::Scalar::Number(NumberScalar::Int64(1)), + value: databend_common_expression::Scalar::Number(1i64.into()), } .into(); } @@ -1819,7 +1819,7 @@ impl<'a> TypeChecker<'a> { &mut self, span: Span, func_name: &str, - params: Vec, + params: Vec, arguments: &[&Expr], ) -> Result> { // Check if current function is a virtual function, e.g. `database`, `version` @@ -1885,7 +1885,7 @@ impl<'a> TypeChecker<'a> { &self, span: Span, func_name: &str, - params: Vec, + params: Vec, args: Vec, ) -> Result> { // Type check @@ -1893,7 +1893,7 @@ impl<'a> TypeChecker<'a> { let raw_expr = RawExpr::FunctionCall { span, name: func_name.to_string(), - params: params.clone(), + params: params.iter().map(|x| Scalar::Number((*x).into())).collect(), args: arguments, }; let expr = type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?; @@ -2947,7 +2947,7 @@ impl<'a> TypeChecker<'a> { let value = FunctionCall { span, - params: vec![idx + 1], + params: vec![(idx + 1) as _], arguments: vec![scalar.clone()], func_name: "get".to_string(), } @@ -3073,7 +3073,7 @@ impl<'a> TypeChecker<'a> { scalar = FunctionCall { span: expr.span(), func_name: "get".to_string(), - params: vec![idx], + params: vec![idx as _], arguments: vec![scalar.clone()], } .into(); @@ -3189,7 +3189,7 @@ impl<'a> TypeChecker<'a> { while let Some((idx, table_data_type)) = index_with_types.pop_front() { scalar = FunctionCall { span, - params: vec![idx], + params: vec![idx as _], arguments: vec![scalar.clone()], func_name: "get".to_string(), }