Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(query): refactor function params to be scalar #14079

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/query/expression/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::types::nullable::NullableDomain;
use crate::types::BooleanType;
use crate::types::DataType;
use crate::types::NullableType;
use crate::types::NumberScalar;
use crate::values::Column;
use crate::values::ColumnBuilder;
use crate::values::Scalar;
Expand Down Expand Up @@ -750,7 +751,10 @@ impl<'a> Evaluator<'a> {
};

let params = if let DataType::Decimal(ty) = dest_type.remove_nullable() {
vec![ty.precision() as usize, ty.scale() as usize]
vec![
Scalar::Number(NumberScalar::Int64(ty.precision() as _)),
Scalar::Number(NumberScalar::Int64(ty.scale() as _)),
]
} else {
vec![]
};
Expand Down Expand Up @@ -1618,7 +1622,10 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
};

let params = if let DataType::Decimal(ty) = dest_type {
vec![ty.precision() as usize, ty.scale() as usize]
vec![
Scalar::Number(NumberScalar::Int64(ty.precision() as _)),
Scalar::Number(NumberScalar::Int64(ty.scale() as _)),
]
} else {
vec![]
};
Expand Down
2 changes: 1 addition & 1 deletion src/query/expression/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub enum RawExpr<Index: ColumnIndex = usize> {
FunctionCall {
span: Span,
name: String,
params: Vec<usize>,
params: Vec<Scalar>,
args: Vec<RawExpr<Index>>,
},
LambdaFunctionCall {
Expand Down
10 changes: 5 additions & 5 deletions src/query/expression/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub type AutoCastRules<'a> = &'a [(DataType, DataType)];
///
/// The first argument is the const parameters and the second argument is the types of arguments.
pub trait FunctionFactory =
Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static;
Fn(&[Scalar], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static;

pub struct Function {
pub signature: FunctionSignature,
Expand Down Expand Up @@ -131,7 +131,7 @@ pub enum FunctionID {
Factory {
name: String,
id: usize,
params: Vec<usize>,
params: Vec<Scalar>,
args_type: Vec<DataType>,
},
}
Expand Down Expand Up @@ -327,7 +327,7 @@ impl FunctionRegistry {
pub fn search_candidates<Index: ColumnIndex>(
&self,
name: &str,
params: &[usize],
params: &[Scalar],
args: &[Expr<Index>],
) -> Vec<(FunctionID, Arc<Function>)> {
let name = name.to_lowercase();
Expand Down Expand Up @@ -522,7 +522,7 @@ impl FunctionID {
}
}

pub fn params(&self) -> &[usize] {
pub fn params(&self) -> &[Scalar] {
match self {
FunctionID::Builtin { .. } => &[],
FunctionID::Factory { params, .. } => params.as_slice(),
Expand Down Expand Up @@ -558,7 +558,7 @@ impl<'a> EvalContext<'a> {
pub fn render_error(
&self,
span: Span,
params: &[usize],
params: &[Scalar],
args: &[Value<AnyType>],
func_name: &str,
) -> Result<()> {
Expand Down
12 changes: 7 additions & 5 deletions src/query/expression/src/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,12 @@ pub fn check<Index: ColumnIndex>(
} else {
new_args.push(Expr::Constant {
span: None,
scalar: Scalar::Number(NumberScalar::Int64(scale)),
scalar: Scalar::Number(scale.into()),
data_type: Int64Type::data_type(),
})
}
scale = scale.clamp(-76, 76);
let add_on_scale = (scale + 76) as usize;
let params = vec![add_on_scale];
let params = vec![Scalar::Number(scale.into())];
return check_function(*span, name, &params, &args_expr, fn_registry);
}

Expand Down Expand Up @@ -206,7 +205,10 @@ pub fn check_cast<Index: ColumnIndex>(
// fast path to eval function for cast
if let Some(cast_fn) = get_simple_cast_function(is_try, dest_type) {
let params = if let DataType::Decimal(ty) = dest_type {
vec![ty.precision() as usize, ty.scale() as usize]
vec![
Scalar::Number(NumberScalar::Int64(ty.precision() as _)),
Scalar::Number(NumberScalar::Int64(ty.scale() as _)),
]
} else {
vec![]
};
Expand Down Expand Up @@ -286,7 +288,7 @@ pub fn check_number<Index: ColumnIndex, T: Number>(
pub fn check_function<Index: ColumnIndex>(
span: Span,
name: &str,
params: &[usize],
params: &[Scalar],
args: &[Expr<Index>],
fn_registry: &FunctionRegistry,
) -> Result<Expr<Index>> {
Expand Down
11 changes: 11 additions & 0 deletions src/query/expression/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use borsh::BorshSerialize;
use databend_common_arrow::arrow::buffer::Buffer;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_io::display_decimal_128;
use databend_common_io::display_decimal_256;
use enum_as_inner::EnumAsInner;
use ethnum::i256;
use ethnum::AsI256;
Expand Down Expand Up @@ -307,6 +309,7 @@ pub trait Decimal:
fn from_i128<U: Into<i128>>(value: U) -> Self;

fn de_binary(bytes: &mut &[u8]) -> Self;
fn display(self, scale: u8) -> String;

fn to_float32(self, scale: u8) -> f32;
fn to_float64(self, scale: u8) -> f64;
Expand Down Expand Up @@ -453,6 +456,10 @@ impl Decimal for i128 {
i128::from_le_bytes(bs)
}

fn display(self, scale: u8) -> String {
display_decimal_128(self, scale)
}

fn to_float32(self, scale: u8) -> f32 {
let div = 10_f32.powi(scale as i32);
self as f32 / div
Expand Down Expand Up @@ -618,6 +625,10 @@ impl Decimal for i256 {
i256::from_le_bytes(bs)
}

fn display(self, scale: u8) -> String {
display_decimal_256(self, scale)
}

fn to_float32(self, scale: u8) -> f32 {
let div = 10_f32.powi(scale as i32);
self.as_f32() / div
Expand Down
8 changes: 8 additions & 0 deletions src/query/expression/src/types/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ impl NumberScalar {
}
}

impl<T> From<T> for NumberScalar
where T: Number
{
fn from(value: T) -> Self {
T::upcast_scalar(value)
}
}

impl NumberColumn {
pub fn len(&self) -> usize {
crate::with_number_type!(|NUM_TYPE| match self {
Expand Down
17 changes: 17 additions & 0 deletions src/query/expression/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,23 @@ impl Scalar {
_ => unreachable!("is_positive() called on non-numeric scalar"),
}
}

pub fn get_i64(&self) -> Option<i64> {
match self {
Scalar::Number(n) => match n {
NumberScalar::Int8(x) => Some(*x as _),
NumberScalar::Int16(x) => Some(*x as _),
NumberScalar::Int32(x) => Some(*x as _),
NumberScalar::Int64(x) => Some(*x as _),
NumberScalar::UInt8(x) => Some(*x as _),
NumberScalar::UInt16(x) => Some(*x as _),
NumberScalar::UInt32(x) => Some(*x as _),
NumberScalar::UInt64(x) => i64::try_from(*x).ok(),
_ => None,
},
_ => None,
}
}
}

impl<'a> ScalarRef<'a> {
Expand Down
124 changes: 16 additions & 108 deletions src/query/functions/src/scalars/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,15 @@ use std::ops::BitXor;
use std::sync::Arc;

use databend_common_arrow::arrow::bitmap::Bitmap;
use databend_common_expression::types::decimal::Decimal;
use databend_common_expression::types::decimal::DecimalColumn;
use databend_common_expression::types::decimal::DecimalDomain;
use databend_common_expression::types::decimal::DecimalType;
use databend_common_expression::types::nullable::NullableColumn;
use databend_common_expression::types::nullable::NullableDomain;
use databend_common_expression::types::number::Number;
use databend_common_expression::types::number::NumberType;
use databend_common_expression::types::number::F64;
use databend_common_expression::types::string::StringColumnBuilder;
use databend_common_expression::types::AnyType;
use databend_common_expression::types::ArgType;
use databend_common_expression::types::DataType;
use databend_common_expression::types::DecimalDataType;
use databend_common_expression::types::NullableType;
Expand All @@ -51,23 +49,19 @@ use databend_common_expression::values::ValueRef;
use databend_common_expression::vectorize_1_arg;
use databend_common_expression::vectorize_with_builder_1_arg;
use databend_common_expression::vectorize_with_builder_2_arg;
use databend_common_expression::with_decimal_mapped_type;
use databend_common_expression::with_float_mapped_type;
use databend_common_expression::with_integer_mapped_type;
use databend_common_expression::with_number_mapped_type;
use databend_common_expression::with_number_mapped_type_without_64;
use databend_common_expression::with_unsigned_integer_mapped_type;
use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::Domain;
use databend_common_expression::EvalContext;
use databend_common_expression::Function;
use databend_common_expression::FunctionDomain;
use databend_common_expression::FunctionEval;
use databend_common_expression::FunctionRegistry;
use databend_common_expression::FunctionSignature;
use databend_common_expression::Scalar;
use databend_common_io::display_decimal_128;
use databend_common_io::display_decimal_256;
use ethnum::i256;
use lexical_core::FormattedSize;
use num_traits::AsPrimitive;
Expand All @@ -76,6 +70,7 @@ use super::arithmetic_modulo::vectorize_modulo;
use super::decimal::register_decimal_to_int;
use crate::scalars::decimal::register_decimal_arithmetic;
use crate::scalars::decimal::register_decimal_to_float;
use crate::scalars::decimal::register_decimal_to_string;

pub fn register(registry: &mut FunctionRegistry) {
registry.register_aliases("plus", &["add"]);
Expand Down Expand Up @@ -775,7 +770,7 @@ pub fn register_decimal_minus(registry: &mut FunctionRegistry) {
}
_ => unreachable!(),
}),
eval: Box::new(move |args, _tx| unary_minus_decimal(args, arg_type.clone())),
eval: Box::new(move |args, ctx| unary_minus_decimal(args, arg_type.clone(), ctx)),
},
};

Expand All @@ -787,34 +782,20 @@ pub fn register_decimal_minus(registry: &mut FunctionRegistry) {
});
}

fn unary_minus_decimal(args: &[ValueRef<AnyType>], arg_type: DataType) -> Value<AnyType> {
fn unary_minus_decimal(
args: &[ValueRef<AnyType>],
arg_type: DataType,
ctx: &mut EvalContext,
) -> Value<AnyType> {
let arg = &args[0];
let mut is_scalar = false;
let column = match arg {
ValueRef::Column(column) => column.clone(),
ValueRef::Scalar(s) => {
is_scalar = true;
let builder = ColumnBuilder::repeat(s, 1, &arg_type);
builder.build()
}
};

let result = match column {
Column::Decimal(DecimalColumn::Decimal128(buf, size)) => {
DecimalColumn::Decimal128(buf.into_iter().map(|x| -x).collect(), size)
}
Column::Decimal(DecimalColumn::Decimal256(buf, size)) => {
DecimalColumn::Decimal256(buf.into_iter().map(|x| -x).collect(), size)
let arg_type = arg_type.as_decimal().unwrap();
with_decimal_mapped_type!(|DECIMAL_TYPE| match arg_type {
DecimalDataType::DECIMAL_TYPE(size) => {
type Type = DecimalType<DECIMAL_TYPE>;
let arg = arg.try_downcast().unwrap();
vectorize_1_arg::<Type, Type>(|t, _| -t)(arg, ctx).upcast_decimal(*size)
}
_ => unreachable!(),
};

if is_scalar {
let scalar = result.index(0).unwrap();
Value::Scalar(Scalar::Decimal(scalar))
} else {
Value::Column(Column::Decimal(result))
}
})
}

fn register_string_to_number(registry: &mut FunctionRegistry) {
Expand Down Expand Up @@ -953,76 +934,3 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) {
});
}
}

fn register_decimal_to_string(registry: &mut FunctionRegistry) {
// decimal to string
registry.register_function_factory("to_string", |_params, args_type| {
if args_type.len() != 1 {
return None;
}

let arg_type = args_type[0].remove_nullable();
if !arg_type.is_decimal() {
return None;
}

Some(Arc::new(Function {
signature: FunctionSignature {
name: "to_string".to_string(),
args_type: vec![arg_type.clone()],
return_type: StringType::data_type(),
},
eval: FunctionEval::Scalar {
calc_domain: Box::new(|_, _| FunctionDomain::Full),
eval: Box::new(move |args, tx| decimal_to_string(args, arg_type.clone(), tx)),
},
}))
});
}

fn decimal_to_string(
args: &[ValueRef<AnyType>],
from_type: DataType,
_ctx: &mut EvalContext,
) -> Value<AnyType> {
let arg = &args[0];

let mut is_scalar = false;
let column = match arg {
ValueRef::Column(column) => column.clone(),
ValueRef::Scalar(s) => {
is_scalar = true;
let builder = ColumnBuilder::repeat(s, 1, &from_type);
builder.build()
}
};

let from_type = from_type.as_decimal().unwrap();

let column = match from_type {
DecimalDataType::Decimal128(_) => {
let (buffer, from_size) = i128::try_downcast_column(&column).unwrap();
let mut builder = StringColumnBuilder::with_capacity(buffer.len(), buffer.len() * 10);
for x in buffer {
builder.put_str(&display_decimal_128(x, from_size.scale));
builder.commit_row();
}
builder
}
DecimalDataType::Decimal256(_) => {
let (buffer, from_size) = i256::try_downcast_column(&column).unwrap();
let mut builder = StringColumnBuilder::with_capacity(buffer.len(), buffer.len() * 10);
for x in buffer {
builder.put_str(&display_decimal_256(x, from_size.scale));
builder.commit_row();
}
builder
}
};

if is_scalar {
Value::Scalar(Scalar::String(column.build_scalar()))
} else {
Value::Column(Column::String(column.build()))
}
}
Loading
Loading