Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ datafusion = { workspace = true }
futures = "0.3"
async-trait = "0.1"
chrono = { version = "0.4", features = ["std"] }
rust_decimal = { version = "1.35", features = ["db-postgres"] }
39 changes: 39 additions & 0 deletions datafusion-postgres/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use timezone::Tz;

pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
Expand All @@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
DataType::Float64 => Type::FLOAT8,
DataType::Decimal128(_, _) => Type::NUMERIC,
DataType::Utf8 => Type::VARCHAR,
DataType::LargeUtf8 => Type::TEXT,
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
Expand Down Expand Up @@ -83,6 +86,11 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
})
}

fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> Decimal {
let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
Decimal::from_i128_with_scale(array.value(idx), scale)
}

fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
arr.as_any()
.downcast_ref::<BooleanArray>()
Expand Down Expand Up @@ -258,6 +266,9 @@ fn encode_value(
DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?,
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
DataType::Decimal128(_, s) => {
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32))?
}
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?,
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?,
Expand Down Expand Up @@ -361,6 +372,17 @@ fn encode_value(
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
DataType::Decimal128(_, s) => {
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
let value: Vec<_> = list_arr
.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap()
.iter()
.map(|v| Decimal::from_i128_with_scale(v.unwrap(), *s as u32))
.collect();
encoder.encode_field(&value)?
}
DataType::Utf8 => {
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
let value: Vec<_> = list_arr
Expand Down Expand Up @@ -808,6 +830,23 @@ where
let value = portal.parameter::<f64>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Float64(value));
}
Type::NUMERIC => {
let value = match portal.parameter::<Decimal>(i, &pg_type)? {
None => ScalarValue::Decimal128(None, 0, 0),
Some(value) => {
let mantissa = value.mantissa();
// Count digits in the mantissa
let precision = if mantissa == 0 {
1
} else {
(mantissa.abs() as f64).log10().floor() as u8 + 1
};
let scale = value.scale() as i8;
ScalarValue::Decimal128(value.to_i128(), precision, scale)
}
};
deserialized_params.push(value);
}
Type::TIMESTAMP => {
let value = portal.parameter::<NaiveDateTime>(i, &pg_type)?;
deserialized_params.push(ScalarValue::TimestampMicrosecond(
Expand Down
Loading