Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,8 @@ required-features = ["math_expressions"]
harness = false
name = "floor_ceil"
required-features = ["math_expressions"]

[[bench]]
harness = false
name = "round"
required-features = ["math_expressions"]
154 changes: 154 additions & 0 deletions datafusion/functions/benches/round.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::datatypes::{DataType, Field, Float32Type, Float64Type};
use arrow::util::bench_util::create_primitive_array;
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::math::round;
use std::hint::black_box;
use std::sync::Arc;
use std::time::Duration;

fn criterion_benchmark(c: &mut Criterion) {
let round_fn = round();
let config_options = Arc::new(ConfigOptions::default());

for size in [1024, 4096, 8192] {
let mut group = c.benchmark_group(format!("round size={size}"));
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

// Float64 array benchmark
let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.1));
let batch_len = f64_array.len();
let f64_args = vec![
ColumnarValue::Array(f64_array),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f64_array", |b| {
b.iter(|| {
let args_cloned = f64_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float64, true).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: batch_len,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

// Float32 array benchmark
let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.1));
let f32_args = vec![
ColumnarValue::Array(f32_array),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f32_array", |b| {
b.iter(|| {
let args_cloned = f32_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float32, true).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: batch_len,
return_field: Field::new("f", DataType::Float32, true).into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

// Scalar benchmark (the optimization we added)
let scalar_f64_args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f64_scalar", |b| {
b.iter(|| {
let args_cloned = scalar_f64_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float64, false).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Float64, false)
.into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let scalar_f32_args = vec![
ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];

group.bench_function("round_f32_scalar", |b| {
b.iter(|| {
let args_cloned = scalar_f32_args.clone();
black_box(
round_fn
.invoke_with_args(ScalarFunctionArgs {
args: args_cloned,
arg_fields: vec![
Field::new("a", DataType::Float32, false).into(),
Field::new("b", DataType::Int32, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Float32, false)
.into(),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

group.finish();
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
200 changes: 198 additions & 2 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand Down Expand Up @@ -141,7 +141,67 @@ impl ScalarUDFImpl for RoundFunc {
&default_decimal_places
};

round_columnar(&args.args[0], decimal_places, args.number_rows)
// Scalar fast path for float and decimal types - avoid array conversion overhead

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
            (&args.args[0], decimal_places)
        {
            if value_scalar.is_null() || dp_scalar.is_null() {
                return ColumnarValue::Scalar(ScalarValue::Null)
                    .cast_to(args.return_type(), None);
            }

            let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
                *dp
            } else {
                return internal_err!(
                    "Unexpected datatype for decimal_places: {}",
                    dp_scalar.data_type()
                );
            };

            match value_scalar {
                ScalarValue::Float32(Some(v)) => {
                    let rounded = round_float(*v, dp)?;
                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                }
                ScalarValue::Float64(Some(v)) => {
                    let rounded = round_float(*v, dp)?;
                    Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                }
                ScalarValue::Decimal128(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal128(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal256(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal256(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal64(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal64(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                ScalarValue::Decimal32(Some(v), precision, scale) => {
                    let rounded = round_decimal(*v, *scale, dp)?;
                    let scalar =
                        ScalarValue::Decimal32(Some(rounded), *precision, *scale);
                    Ok(ColumnarValue::Scalar(scalar))
                }
                _ => {
                    internal_err!(
                        "Unexpected datatype for value: {}",
                        value_scalar.data_type()
                    )
                }
            }
        } else {
            round_columnar(&args.args[0], decimal_places, args.number_rows)
        }

Cleaner way of doing this

  • Use internal_err, which is more appropriate here than exec_err
  • Collapse null handling using ScalarValue::is_null and ColumnarValue::cast_to
  • No need to map the errors from round_float and round_decimal, because using ? does that for us

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks much better.

if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
(&args.args[0], decimal_places)
{
if value_scalar.is_null() || dp_scalar.is_null() {
return ColumnarValue::Scalar(ScalarValue::Null)
.cast_to(args.return_type(), None);
}

let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
*dp
} else {
return internal_err!(
"Unexpected datatype for decimal_places: {}",
dp_scalar.data_type()
);
};

match value_scalar {
ScalarValue::Float32(Some(v)) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
ScalarValue::Float64(Some(v)) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let scalar =
ScalarValue::Decimal128(Some(rounded), *precision, *scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let scalar =
ScalarValue::Decimal256(Some(rounded), *precision, *scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal64(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let scalar =
ScalarValue::Decimal64(Some(rounded), *precision, *scale);
Ok(ColumnarValue::Scalar(scalar))
}
ScalarValue::Decimal32(Some(v), precision, scale) => {
let rounded = round_decimal(*v, *scale, dp)?;
let scalar =
ScalarValue::Decimal32(Some(rounded), *precision, *scale);
Ok(ColumnarValue::Scalar(scalar))
}
_ => {
internal_err!(
"Unexpected datatype for value: {}",
value_scalar.data_type()
)
}
}
} else {
round_columnar(&args.args[0], decimal_places, args.number_rows)
}
}

fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
Expand Down Expand Up @@ -434,4 +494,140 @@ mod test {
Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_))
));
}

// Tests for scalar fast path
use super::RoundFunc;
use arrow::datatypes::{DataType, Field};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl};

/// Test helper: invokes `RoundFunc` on a single scalar `value` and returns the
/// result as a `ScalarValue`.
///
/// `decimal_places` defaults to `0` when `None` and is always passed to the
/// function as a non-null `Int32` scalar. If the function returns an array,
/// its first element is extracted.
fn round_scalar(
    value: ScalarValue,
    decimal_places: Option<i32>,
) -> Result<ScalarValue, DataFusionError> {
    let udf = RoundFunc::new();
    let value_type = value.data_type();
    let return_type = udf.return_type(&[value_type.clone()])?;
    let places = ScalarValue::Int32(Some(decimal_places.unwrap_or(0)));

    let output = udf.invoke_with_args(ScalarFunctionArgs {
        args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(places)],
        arg_fields: vec![
            Field::new("a", value_type, true).into(),
            Field::new("b", DataType::Int32, false).into(),
        ],
        number_rows: 1,
        return_field: Field::new("f", return_type, true).into(),
        config_options: Arc::new(ConfigOptions::default()),
    })?;

    match output {
        ColumnarValue::Array(array) => ScalarValue::try_from_array(&array, 0),
        ColumnarValue::Scalar(scalar) => Ok(scalar),
    }
}

#[test]
fn test_round_scalar_f64() {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be SLTs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the scalar unit tests; these cases are already covered by the SLTs.

// Test basic rounding
let result = round_scalar(ScalarValue::Float64(Some(125.2345)), Some(2)).unwrap();
assert_eq!(result, ScalarValue::Float64(Some(125.23)));

// Test negative value
let result =
round_scalar(ScalarValue::Float64(Some(-125.2345)), Some(2)).unwrap();
assert_eq!(result, ScalarValue::Float64(Some(-125.23)));

// Test negative decimal places
let result =
round_scalar(ScalarValue::Float64(Some(12345.55)), Some(-1)).unwrap();
assert_eq!(result, ScalarValue::Float64(Some(12350.0)));

// Test null
let result = round_scalar(ScalarValue::Float64(None), Some(2)).unwrap();
assert_eq!(result, ScalarValue::Float64(None));
}

#[test]
fn test_round_scalar_f32() {
    // (input, expected) pairs, each rounded to two decimal places:
    // a positive value, a negative value, and a null.
    let cases = [
        (Some(125.2345), Some(125.23)),
        (Some(-125.2345), Some(-125.23)),
        (None, None),
    ];

    for (input, expected) in cases {
        let result = round_scalar(ScalarValue::Float32(input), Some(2)).unwrap();
        assert_eq!(result, ScalarValue::Float32(expected));
    }
}

#[test]
fn test_round_scalar_decimal128() {
    // Unscaled (input, expected) pairs at precision 10, scale 5, rounded to two
    // decimal places: 3.14159 -> 3.14000, -3.14159 -> -3.14000, null -> null.
    let cases = [
        (Some(314159), Some(314000)),
        (Some(-314159), Some(-314000)),
        (None, None),
    ];

    for (input, expected) in cases {
        let result =
            round_scalar(ScalarValue::Decimal128(input, 10, 5), Some(2)).unwrap();
        assert_eq!(result, ScalarValue::Decimal128(expected, 10, 5));
    }
}

#[test]
fn test_round_scalar_decimal256() {
    use arrow::datatypes::i256;

    // Unscaled (input, expected) pairs at precision 50, scale 5, rounded to two
    // decimal places: 3.14159 -> 3.14000, null -> null.
    let cases = [
        (Some(i256::from(314159)), Some(i256::from(314000))),
        (None, None),
    ];

    for (input, expected) in cases {
        let result =
            round_scalar(ScalarValue::Decimal256(input, 50, 5), Some(2)).unwrap();
        assert_eq!(result, ScalarValue::Decimal256(expected, 50, 5));
    }
}

#[test]
fn test_round_scalar_decimal64() {
    // Unscaled (input, expected) pairs at precision 10, scale 5, rounded to two
    // decimal places: 3.14159 -> 3.14000, null -> null.
    let cases = [(Some(314159), Some(314000)), (None, None)];

    for (input, expected) in cases {
        let result =
            round_scalar(ScalarValue::Decimal64(input, 10, 5), Some(2)).unwrap();
        assert_eq!(result, ScalarValue::Decimal64(expected, 10, 5));
    }
}

#[test]
fn test_round_scalar_decimal32() {
    // Unscaled (input, expected) pairs at precision 7, scale 4, rounded to two
    // decimal places: 3.1416 -> 3.1400, null -> null.
    let cases = [(Some(31416), Some(31400)), (None, None)];

    for (input, expected) in cases {
        let result =
            round_scalar(ScalarValue::Decimal32(input, 7, 4), Some(2)).unwrap();
        assert_eq!(result, ScalarValue::Decimal32(expected, 7, 4));
    }
}
}