diff --git a/Cargo.lock b/Cargo.lock index 28f9d33..1701b8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1358,6 +1358,7 @@ dependencies = [ "datafusion", "futures", "pgwire", + "rust_decimal", ] [[package]] diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 1a4ae7d..c6306cf 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -21,3 +21,4 @@ datafusion = { workspace = true } futures = "0.3" async-trait = "0.1" chrono = { version = "0.4", features = ["std"] } +rust_decimal = { version = "1.35", features = ["db-postgres"] } diff --git a/datafusion-postgres/src/datatypes.rs b/datafusion-postgres/src/datatypes.rs index 7110296..3660d40 100644 --- a/datafusion-postgres/src/datatypes.rs +++ b/datafusion-postgres/src/datatypes.rs @@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal}; use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse}; use pgwire::api::Type; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use rust_decimal::prelude::ToPrimitive; +use rust_decimal::{Decimal, Error}; use timezone::Tz; pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { @@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA, DataType::Float16 | DataType::Float32 => Type::FLOAT4, DataType::Float64 => Type::FLOAT8, + DataType::Decimal128(_, _) => Type::NUMERIC, DataType::Utf8 => Type::VARCHAR, DataType::LargeUtf8 => Type::TEXT, DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { @@ -83,6 +86,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { }) } +fn get_numeric_128_value(arr: &Arc, idx: usize, scale: u32) -> PgWireResult { + let array = arr.as_any().downcast_ref::().unwrap(); + let value = array.value(idx); + Decimal::try_from_i128_with_scale(value, scale).map_err(|e| { + let message = match e { + Error::ExceedsMaximumPossibleValue => "Exceeds maximum possible value", + Error::LessThanMinimumPossibleValue => "Less than minimum possible value", + Error::ScaleExceedsMaximumPrecision(_) => "Scale exceeds maximum precision", + _ => unreachable!(), + }; + PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + message.to_owned(), + ))) + }) +} + fn get_bool_value(arr: &Arc, idx: usize) -> bool { arr.as_any() .downcast_ref::() @@ -258,6 +279,9 @@ fn encode_value( DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?, DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?, DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?, + DataType::Decimal128(_, s) => { + encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?)? + } DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?, DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?, DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?, @@ -361,6 +385,17 @@ fn encode_value( DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?, DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?, DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?, + DataType::Decimal128(_, s) => { + let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); + let value: Vec<_> = list_arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32))) + .collect(); + encoder.encode_field(&value)? + } DataType::Utf8 => { let list_arr = arr.as_any().downcast_ref::().unwrap().value(idx); let value: Vec<_> = list_arr @@ -711,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>( for col in 0..cols { let array = rb.column(col); if array.is_null(row) { - encoder.encode_field(&None::).unwrap(); + encoder.encode_field(&None::)?; } else { - encode_value(&mut encoder, array, row).unwrap(); + encode_value(&mut encoder, array, row)? } } encoder.finish() @@ -808,6 +843,20 @@ where let value = portal.parameter::(i, &pg_type)?; deserialized_params.push(ScalarValue::Float64(value)); } + Type::NUMERIC => { + let value = match portal.parameter::(i, &pg_type)? { + None => ScalarValue::Decimal128(None, 0, 0), + Some(value) => { + let precision = match value.mantissa() { + 0 => 1, + m => (m.abs() as f64).log10().floor() as u8 + 1, + }; + let scale = value.scale() as i8; + ScalarValue::Decimal128(value.to_i128(), precision, scale) + } + }; + deserialized_params.push(value); + } Type::TIMESTAMP => { let value = portal.parameter::(i, &pg_type)?; deserialized_params.push(ScalarValue::TimestampMicrosecond(