Skip to content

Commit 70e7e62

Browse files
committed
More rigorous treatment of floats in tests
1 parent fc49e8d commit 70e7e62

27 files changed

+1141
-1093
lines changed

Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ arrow-ipc = { version = "53.3.0", default-features = false, features = [
9292
arrow-ord = { version = "53.3.0", default-features = false }
9393
arrow-schema = { version = "53.3.0", default-features = false }
9494
async-trait = "0.1.73"
95-
bigdecimal = "0.4.6"
9695
bytes = "1.4"
9796
chrono = { version = "0.4.38", default-features = false }
9897
ctor = "0.2.0"

datafusion/core/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ rand = { workspace = true, features = ["small_rng"] }
146146
rand_distr = "0.4.3"
147147
regex = { workspace = true }
148148
rstest = { workspace = true }
149-
rust_decimal = { version = "1.27.0", features = ["tokio-pg"] }
150149
serde_json = { workspace = true }
151150
test-utils = { path = "../../test-utils" }
152151
tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] }

datafusion/sqllogictest/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ path = "src/lib.rs"
3636
[dependencies]
3737
arrow = { workspace = true }
3838
async-trait = { workspace = true }
39-
bigdecimal = { workspace = true }
4039
bytes = { workspace = true, optional = true }
4140
chrono = { workspace = true, optional = true }
4241
clap = { version = "4.5.16", features = ["derive", "env"] }
@@ -50,7 +49,8 @@ log = { workspace = true }
5049
object_store = { workspace = true }
5150
postgres-protocol = { version = "0.6.4", optional = true }
5251
postgres-types = { version = "0.2.4", optional = true }
53-
rust_decimal = { version = "1.27.0" }
52+
rust_decimal = { version = "1.27.0", features = ["tokio-pg"], optional = true }
53+
ryu = "1.0.18"
5454
sqllogictest = "0.23.0"
5555
sqlparser = { workspace = true }
5656
tempfile = { workspace = true }
@@ -66,6 +66,7 @@ postgres = [
6666
"tokio-postgres",
6767
"postgres-types",
6868
"postgres-protocol",
69+
"rust_decimal",
6970
]
7071

7172
[dev-dependencies]

datafusion/sqllogictest/src/engines/conversion.rs

+17-90
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType};
19-
use bigdecimal::BigDecimal;
2019
use half::f16;
21-
use rust_decimal::prelude::*;
2220

2321
/// Represents a constant for NULL string in your database.
2422
pub const NULL_STR: &str = "NULL";
@@ -40,17 +38,7 @@ pub(crate) fn varchar_to_str(value: &str) -> String {
4038
}
4139

4240
pub(crate) fn f16_to_str(value: f16) -> String {
43-
if value.is_nan() {
44-
// The sign of NaN can be different depending on platform.
45-
// So the string representation of NaN ignores the sign.
46-
"NaN".to_string()
47-
} else if value == f16::INFINITY {
48-
"Infinity".to_string()
49-
} else if value == f16::NEG_INFINITY {
50-
"-Infinity".to_string()
51-
} else {
52-
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
53-
}
41+
f32_to_str(value.to_f32())
5442
}
5543

5644
pub(crate) fn f32_to_str(value: f32) -> String {
@@ -63,7 +51,7 @@ pub(crate) fn f32_to_str(value: f32) -> String {
6351
} else if value == f32::NEG_INFINITY {
6452
"-Infinity".to_string()
6553
} else {
66-
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
54+
trim_decimal_trailing_zeros(ryu::Buffer::new().format(value).to_string())
6755
}
6856
}
6957

@@ -77,94 +65,33 @@ pub(crate) fn f64_to_str(value: f64) -> String {
7765
} else if value == f64::NEG_INFINITY {
7866
"-Infinity".to_string()
7967
} else {
80-
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
68+
trim_decimal_trailing_zeros(ryu::Buffer::new().format(value).to_string())
8169
}
8270
}
8371

8472
pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String {
85-
big_decimal_to_str(
86-
BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale))
87-
.unwrap(),
88-
)
73+
trim_decimal_trailing_zeros(Decimal128Type::format_decimal(value, *precision, *scale))
8974
}
9075

9176
pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String {
92-
big_decimal_to_str(
93-
BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale))
94-
.unwrap(),
95-
)
77+
trim_decimal_trailing_zeros(Decimal256Type::format_decimal(value, *precision, *scale))
9678
}
9779

9880
#[cfg(feature = "postgres")]
99-
pub(crate) fn decimal_to_str(value: Decimal) -> String {
100-
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
101-
}
102-
103-
pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String {
104-
// Round the value to limit the number of decimal places
105-
let value = value.round(12).normalized();
106-
// Format the value to a string
107-
format_big_decimal(value)
81+
pub(crate) fn decimal_to_str(value: rust_decimal::Decimal) -> String {
82+
trim_decimal_trailing_zeros(value.to_string())
10883
}
10984

110-
fn format_big_decimal(value: BigDecimal) -> String {
111-
let (integer, scale) = value.into_bigint_and_exponent();
112-
let mut str = integer.to_str_radix(10);
113-
if scale <= 0 {
114-
// Append zeros to the right of the integer part
115-
str.extend(std::iter::repeat('0').take(scale.unsigned_abs() as usize));
116-
str
117-
} else {
118-
let (sign, unsigned_len, unsigned_str) = if integer.is_negative() {
119-
("-", str.len() - 1, &str[1..])
120-
} else {
121-
("", str.len(), &str[..])
122-
};
123-
let scale = scale as usize;
124-
if unsigned_len <= scale {
125-
format!("{}0.{:0>scale$}", sign, unsigned_str)
126-
} else {
127-
str.insert(str.len() - scale, '.');
128-
str
85+
fn trim_decimal_trailing_zeros(mut string: String) -> String {
86+
// Remove trailing zeros after the decimal point
87+
if let Some(decimal_idx) = string.find('.') {
88+
let after_decimal_idx = decimal_idx + 1;
89+
let after = &mut string[after_decimal_idx..];
90+
let trimmed_len = after.trim_end_matches('0').len();
91+
string.truncate(after_decimal_idx + trimmed_len);
92+
if string.ends_with('.') {
93+
string.pop();
12994
}
13095
}
131-
}
132-
133-
#[cfg(test)]
134-
mod tests {
135-
use super::big_decimal_to_str;
136-
use bigdecimal::{num_bigint::BigInt, BigDecimal};
137-
138-
macro_rules! assert_decimal_str_eq {
139-
($integer:expr, $scale:expr, $expected:expr) => {
140-
assert_eq!(
141-
big_decimal_to_str(BigDecimal::from_bigint(
142-
BigInt::from($integer),
143-
$scale
144-
)),
145-
$expected
146-
);
147-
};
148-
}
149-
150-
#[test]
151-
fn test_big_decimal_to_str() {
152-
assert_decimal_str_eq!(11, 3, "0.011");
153-
assert_decimal_str_eq!(11, 2, "0.11");
154-
assert_decimal_str_eq!(11, 1, "1.1");
155-
assert_decimal_str_eq!(11, 0, "11");
156-
assert_decimal_str_eq!(11, -1, "110");
157-
assert_decimal_str_eq!(0, 0, "0");
158-
159-
// Negative cases
160-
assert_decimal_str_eq!(-11, 3, "-0.011");
161-
assert_decimal_str_eq!(-11, 2, "-0.11");
162-
assert_decimal_str_eq!(-11, 1, "-1.1");
163-
assert_decimal_str_eq!(-11, 0, "-11");
164-
assert_decimal_str_eq!(-11, -1, "-110");
165-
166-
// Round to 12 decimal places
167-
// 1.0000000000011 -> 1.000000000001
168-
assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001");
169-
}
96+
string
17097
}

0 commit comments

Comments
 (0)