Skip to content

Commit 283cd4b

Browse files
committed
Improve error handling
1 parent 1ffc805 commit 283cd4b

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

datafusion-postgres/src/datatypes.rs

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
1616
use pgwire::api::Type;
1717
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
1818
use rust_decimal::prelude::ToPrimitive;
19-
use rust_decimal::Decimal;
19+
use rust_decimal::{Decimal, Error};
2020
use timezone::Tz;
2121

2222
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
@@ -86,9 +86,22 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8686
})
8787
}
8888

89-
fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> Decimal {
89+
fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> PgWireResult<Decimal> {
9090
let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
91-
Decimal::from_i128_with_scale(array.value(idx), scale)
91+
let value = array.value(idx);
92+
Decimal::try_from_i128_with_scale(value, scale).map_err(|e| {
93+
let message = match e {
94+
Error::ExceedsMaximumPossibleValue => "Exceeds maximum possible value",
95+
Error::LessThanMinimumPossibleValue => "Less than minimum possible value",
96+
Error::ScaleExceedsMaximumPrecision(_) => "Scale exceeds maximum precision",
97+
_ => unreachable!(),
98+
};
99+
PgWireError::UserError(Box::new(ErrorInfo::new(
100+
"ERROR".to_owned(),
101+
"XX000".to_owned(),
102+
message.to_owned(),
103+
)))
104+
})
92105
}
93106

94107
fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
@@ -267,7 +280,7 @@ fn encode_value(
267280
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
268281
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
269282
DataType::Decimal128(_, s) => {
270-
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32))?
283+
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?)?
271284
}
272285
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
273286
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?,
@@ -379,7 +392,7 @@ fn encode_value(
379392
.downcast_ref::<Decimal128Array>()
380393
.unwrap()
381394
.iter()
382-
.map(|v| Decimal::from_i128_with_scale(v.unwrap(), *s as u32))
395+
.map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
383396
.collect();
384397
encoder.encode_field(&value)?
385398
}
@@ -733,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>(
733746
for col in 0..cols {
734747
let array = rb.column(col);
735748
if array.is_null(row) {
736-
encoder.encode_field(&None::<i8>).unwrap();
749+
encoder.encode_field(&None::<i8>)?;
737750
} else {
738-
encode_value(&mut encoder, array, row).unwrap();
751+
encode_value(&mut encoder, array, row)?
739752
}
740753
}
741754
encoder.finish()
@@ -834,12 +847,9 @@ where
834847
let value = match portal.parameter::<Decimal>(i, &pg_type)? {
835848
None => ScalarValue::Decimal128(None, 0, 0),
836849
Some(value) => {
837-
let mantissa = value.mantissa();
838-
// Count digits in the mantissa
839-
let precision = if mantissa == 0 {
840-
1
841-
} else {
842-
(mantissa.abs() as f64).log10().floor() as u8 + 1
850+
let precision = match value.mantissa() {
851+
0 => 1,
852+
m => (m.abs() as f64).log10().floor() as u8 + 1,
843853
};
844854
let scale = value.scale() as i8;
845855
ScalarValue::Decimal128(value.to_i128(), precision, scale)

0 commit comments

Comments
 (0)