From 315fbcf4df0e56a2134a85f912e78405651dc6e2 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sat, 1 Nov 2025 08:46:41 +0800 Subject: [PATCH 1/5] feat: update type conversion to use field instead of datatype --- arrow-pg/src/datatypes.rs | 91 ++++++++++++++++------------- arrow-pg/src/datatypes/df.rs | 4 +- datafusion-postgres/src/handlers.rs | 3 +- 3 files changed, 55 insertions(+), 43 deletions(-) diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index c3c6276..ac1a45b 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -17,7 +17,9 @@ use crate::row_encoder::RowEncoder; #[cfg(feature = "datafusion")] pub mod df; -pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { +pub fn into_pg_type(field: &Arc) -> PgWireResult { + let arrow_type = field.data_type(); + Ok(match arrow_type { DataType::Null => Type::UNKNOWN, DataType::Boolean => Type::BOOL, @@ -43,46 +45,55 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { DataType::Float64 => Type::FLOAT8, DataType::Decimal128(_, _) => Type::NUMERIC, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, - DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { - match field.data_type() { - DataType::Boolean => Type::BOOL_ARRAY, - DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, - DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, - DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, - DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ_ARRAY - } else { - Type::TIMESTAMP_ARRAY - } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, - DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, - DataType::Interval(_) => Type::INTERVAL_ARRAY, - DataType::FixedSizeBinary(_) - | DataType::Binary - | DataType::LargeBinary - | DataType::BinaryView => Type::BYTEA_ARRAY, - DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, - DataType::Float64 => Type::FLOAT8_ARRAY, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, - struct_type @ DataType::Struct(_) => Type::new( - Type::RECORD_ARRAY.name().into(), - Type::RECORD_ARRAY.oid(), - Kind::Array(into_pg_type(struct_type)?), - Type::RECORD_ARRAY.schema().into(), - ), - list_type => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported List Datatype {list_type}"), - )))); + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => match field.data_type() { + DataType::Boolean => Type::BOOL_ARRAY, + DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, + DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, + DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, + DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ_ARRAY + } else { + Type::TIMESTAMP_ARRAY } } + DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, + DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, + DataType::Interval(_) => Type::INTERVAL_ARRAY, + DataType::FixedSizeBinary(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView => Type::BYTEA_ARRAY, + DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, + DataType::Float64 => Type::FLOAT8_ARRAY, + DataType::Utf8 | DataType::LargeUtf8 | 
DataType::Utf8View => Type::TEXT_ARRAY, + DataType::Struct(_) => Type::new( + Type::RECORD_ARRAY.name().into(), + Type::RECORD_ARRAY.oid(), + Kind::Array(into_pg_type(field)?), + Type::RECORD_ARRAY.schema().into(), + ), + list_type => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported List Datatype {list_type}"), + )))); + } + }, + DataType::Dictionary(_, value_type) => { + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + *value_type.clone(), + true, + )); + into_pg_type(&field)? } - DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, DataType::Struct(fields) => { let name: String = fields .iter() @@ -94,7 +105,7 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { fields .iter() .map(|x| { - into_pg_type(x.data_type()) + into_pg_type(x) .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) }) .collect::, PgWireError>>()?, @@ -117,7 +128,7 @@ pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResu .iter() .enumerate() .map(|(idx, f)| { - let pg_type = into_pg_type(f.data_type())?; + let pg_type = into_pg_type(f)?; Ok(FieldInfo::new( f.name().into(), None, diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index c81d53a..de343b2 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -2,7 +2,7 @@ use std::iter; use std::sync::Arc; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; -use datafusion::arrow::datatypes::{DataType, Date32Type}; +use datafusion::arrow::datatypes::{DataType, Date32Type, Field}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::ParamValues; use datafusion::prelude::*; @@ -61,7 +61,7 @@ where if let Some(ty) = pg_type_hint { Ok(ty.clone()) } else if let Some(infer_type) = inferenced_type { - into_pg_type(infer_type) + into_pg_type(&Arc::new(Field::new("item", infer_type.clone(), true))) } else { Ok(Type::UNKNOWN) } diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 55fafd2..b6069f3 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -391,7 +391,8 @@ impl ExtendedQueryHandler for DfSessionService { for param_type in ordered_param_types(¶ms).iter() { // Fixed: Use ¶ms if let Some(datatype) = param_type { - let pgtype = into_pg_type(datatype)?; + let pgtype = + into_pg_type(&Arc::new(Field::new("item", (*datatype).clone(), true)))?; param_types.push(pgtype); } else { param_types.push(Type::UNKNOWN); From 616cde4560cba7157753f40c1cd9a523d70e69ab Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 2 Nov 2025 03:06:30 +0800 Subject: [PATCH 2/5] feat: pass arrow field all the way the any encoder --- Cargo.lock | 75 +++++++++++++- Cargo.toml | 1 + arrow-pg/Cargo.toml | 5 +- arrow-pg/src/datatypes.rs | 183 +++++++++++++++++---------------- arrow-pg/src/encoder.rs | 28 +++-- arrow-pg/src/list_encoder.rs | 4 +- arrow-pg/src/row_encoder.rs | 9 +- arrow-pg/src/struct_encoder.rs | 10 +- 8 files changed, 206 insertions(+), 109 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf7d015..f1cc3bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,12 +145,21 @@ dependencies = [ "snap", "strum 0.27.2", "strum_macros 0.27.2", - "thiserror", + "thiserror 2.0.17", "uuid", "xz2", "zstd", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "array-init" version = "2.1.0" @@ -336,12 +345,14 @@ name = "arrow-pg" version = "0.8.1" dependencies = [ "arrow", + "arrow-schema", "async-trait", "bytes", "chrono", "datafusion", "duckdb", "futures", + "geoarrow-schema", "pgwire", "postgres-types", "rust_decimal", @@ -1921,6 +1932,39 @@ dependencies = [ "version_check", ] +[[package]] +name = "geo-traits" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7c353d12a704ccfab1ba8bfb1a7fe6cb18b665bf89d37f4f7890edcd260206" +dependencies = [ + "geo-types", +] + +[[package]] +name = "geo-types" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a4dcd69d35b2c87a7c83bce9af69fd65c9d68d3833a0ded568983928f3fc99" +dependencies = [ + "approx", + "num-traits", + "serde", +] + +[[package]] +name = "geoarrow-schema" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02f1b18b1c9a44ecd72be02e53d6e63bbccfdc8d1765206226af227327e2be6e" +dependencies = [ + "arrow-schema", + "geo-traits", + "serde", + "serde_json", + "thiserror 1.0.69", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -2602,7 +2646,7 @@ dependencies = [ "itertools", "parking_lot", "percent-encoding", - "thiserror", + "thiserror 2.0.17", "tokio", "tracing", "url", @@ -2749,7 +2793,7 @@ dependencies = [ "serde", "serde_json", "stringprep", - "thiserror", + "thiserror 2.0.17", "tokio", "tokio-rustls", "tokio-util", @@ -2835,6 +2879,7 @@ dependencies = [ "bytes", "chrono", "fallible-iterator 0.2.0", + "geo-types", "postgres-protocol", "serde_core", "serde_json", @@ -3629,13 +3674,33 @@ dependencies = [ "unicode-width 0.1.14", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.17", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -4321,7 +4386,7 @@ dependencies = [ "ring", "signature", "spki", - "thiserror", + "thiserror 2.0.17", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 10f37d2..8a6bc00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] arrow = "56" +arrow-schema = "56" bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } datafusion = { version = "50", default-features = false } diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 92c60ca..2be4e76 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -13,9 +13,10 @@ readme = "../README.md" rust-version.workspace = true [features] -default = ["arrow"] +default = ["arrow", "geo"] arrow = ["dep:arrow"] datafusion = ["dep:datafusion"] +geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow-schema"] # for testing _duckdb = [] _bundled = ["duckdb/bundled"] @@ -23,6 
+24,8 @@ _bundled = ["duckdb/bundled"] [dependencies] arrow = { workspace = true, optional = true } +arrow-schema = { workspace = true } +geoarrow-schema = { version = "0.6", optional = true } bytes.workspace = true chrono.workspace = true datafusion = { workspace = true, optional = true } diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs index ac1a45b..9114d33 100644 --- a/arrow-pg/src/datatypes.rs +++ b/arrow-pg/src/datatypes.rs @@ -2,6 +2,7 @@ use std::sync::Arc; #[cfg(not(feature = "datafusion"))] use arrow::{datatypes::*, record_batch::RecordBatch}; +use arrow_schema::extension::ExtensionType; #[cfg(feature = "datafusion")] use datafusion::arrow::{datatypes::*, record_batch::RecordBatch}; @@ -20,106 +21,110 @@ pub mod df; pub fn into_pg_type(field: &Arc) -> PgWireResult { let arrow_type = field.data_type(); - Ok(match arrow_type { - DataType::Null => Type::UNKNOWN, - DataType::Boolean => Type::BOOL, - DataType::Int8 | DataType::UInt8 => Type::CHAR, - DataType::Int16 | DataType::UInt16 => Type::INT2, - DataType::Int32 | DataType::UInt32 => Type::INT4, - DataType::Int64 | DataType::UInt64 => Type::INT8, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ - } else { - Type::TIMESTAMP - } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME, - DataType::Date32 | DataType::Date64 => Type::DATE, - DataType::Interval(_) => Type::INTERVAL, - DataType::Binary - | DataType::FixedSizeBinary(_) - | DataType::LargeBinary - | DataType::BinaryView => Type::BYTEA, - DataType::Float16 | DataType::Float32 => Type::FLOAT4, - DataType::Float64 => Type::FLOAT8, - DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) - | DataType::ListView(field) - | DataType::LargeListView(field) => match field.data_type() { - DataType::Boolean => Type::BOOL_ARRAY, - DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, - DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, - DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, - DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, + match field.extension_type_name() { + #[cfg(feature = "geo")] + Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT), + _ => Ok(match arrow_type { + DataType::Null => Type::UNKNOWN, + DataType::Boolean => Type::BOOL, + DataType::Int8 | DataType::UInt8 => Type::CHAR, + DataType::Int16 | DataType::UInt16 => Type::INT2, + DataType::Int32 | DataType::UInt32 => Type::INT4, + DataType::Int64 | DataType::UInt64 => Type::INT8, DataType::Timestamp(_, tz) => { if tz.is_some() { - Type::TIMESTAMPTZ_ARRAY + Type::TIMESTAMPTZ } else { - Type::TIMESTAMP_ARRAY + Type::TIMESTAMP } } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, - DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, - DataType::Interval(_) => Type::INTERVAL_ARRAY, - DataType::FixedSizeBinary(_) - | DataType::Binary + DataType::Time32(_) | DataType::Time64(_) => Type::TIME, + DataType::Date32 | DataType::Date64 => Type::DATE, + DataType::Interval(_) => Type::INTERVAL, + DataType::Binary + | DataType::FixedSizeBinary(_) | DataType::LargeBinary - | DataType::BinaryView => Type::BYTEA_ARRAY, - DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, - DataType::Float64 => Type::FLOAT8_ARRAY, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, - DataType::Struct(_) => Type::new( - Type::RECORD_ARRAY.name().into(), - 
Type::RECORD_ARRAY.oid(), - Kind::Array(into_pg_type(field)?), - Type::RECORD_ARRAY.schema().into(), - ), - list_type => { + | DataType::BinaryView => Type::BYTEA, + DataType::Float16 | DataType::Float32 => Type::FLOAT4, + DataType::Float64 => Type::FLOAT8, + DataType::Decimal128(_, _) => Type::NUMERIC, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => match field.data_type() { + DataType::Boolean => Type::BOOL_ARRAY, + DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, + DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, + DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, + DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ_ARRAY + } else { + Type::TIMESTAMP_ARRAY + } + } + DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, + DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, + DataType::Interval(_) => Type::INTERVAL_ARRAY, + DataType::FixedSizeBinary(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView => Type::BYTEA_ARRAY, + DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, + DataType::Float64 => Type::FLOAT8_ARRAY, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, + DataType::Struct(_) => Type::new( + Type::RECORD_ARRAY.name().into(), + Type::RECORD_ARRAY.oid(), + Kind::Array(into_pg_type(field)?), + Type::RECORD_ARRAY.schema().into(), + ), + list_type => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported List Datatype {list_type}"), + )))); + } + }, + DataType::Dictionary(_, value_type) => { + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + *value_type.clone(), + true, + )); + into_pg_type(&field)? + } + DataType::Struct(fields) => { + let name: String = fields + .iter() + .map(|x| x.name().clone()) + .reduce(|a, b| a + ", " + &b) + .map(|x| format!("({x})")) + .unwrap_or("()".to_string()); + let kind = Kind::Composite( + fields + .iter() + .map(|x| { + into_pg_type(x) + .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) + }) + .collect::, PgWireError>>()?, + ); + Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) + } + _ => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), "XX000".to_owned(), - format!("Unsupported List Datatype {list_type}"), + format!("Unsupported Datatype {arrow_type}"), )))); } - }, - DataType::Dictionary(_, value_type) => { - let field = Arc::new(Field::new( - Field::LIST_FIELD_DEFAULT_NAME, - *value_type.clone(), - true, - )); - into_pg_type(&field)? 
- } - DataType::Struct(fields) => { - let name: String = fields - .iter() - .map(|x| x.name().clone()) - .reduce(|a, b| a + ", " + &b) - .map(|x| format!("({x})")) - .unwrap_or("()".to_string()); - let kind = Kind::Composite( - fields - .iter() - .map(|x| { - into_pg_type(x) - .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) - }) - .collect::, PgWireError>>()?, - ); - Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) - } - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported Datatype {arrow_type}"), - )))); - } - }) + }), + } } pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult> { diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 074939c..65d7f51 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -12,6 +12,7 @@ use chrono::{NaiveDate, NaiveDateTime}; use datafusion::arrow::{array::*, datatypes::*}; use pgwire::api::results::DataRowEncoder; use pgwire::api::results::FieldFormat; +use pgwire::api::results::FieldInfo; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::types::ToSqlText; use postgres_types::{ToSql, Type}; @@ -288,9 +289,12 @@ pub fn encode_value( encoder: &mut T, arr: &Arc, idx: usize, - type_: &Type, - format: FieldFormat, + _arrow_filed: &Field, + pg_field: &FieldInfo, ) -> PgWireResult<()> { + let type_ = pg_field.datatype(); + let format = pg_field.format(); + match arr.data_type() { DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, DataType::Boolean => { @@ -494,7 +498,7 @@ pub fn encode_value( let value = encode_list(array, type_, format)?; encoder.encode_field_with_type_and_format(&value, type_, format)? } - DataType::Struct(_) => { + DataType::Struct(arrow_fields) => { let fields = match type_.kind() { postgres_types::Kind::Composite(fields) => fields, _ => { @@ -503,7 +507,7 @@ pub fn encode_value( )))); } }; - let value = encode_struct(arr, idx, fields, format)?; + let value = encode_struct(arr, idx, arrow_fields, fields, format)?; encoder.encode_field_with_type_and_format(&value, type_, format)? } DataType::Dictionary(_, value_type) => { @@ -534,7 +538,16 @@ pub fn encode_value( )) })?; - encode_value(encoder, values, idx, type_, format)? + let inner_pg_field = FieldInfo::new( + pg_field.name().to_string(), + None, + None, + type_.clone(), + format, + ); + let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true); + + encode_value(encoder, values, idx, &inner_arrow_field, &inner_pg_field)? 
} _ => { return Err(PgWireError::ApiError(ToSqlError::from(format!( @@ -585,7 +598,10 @@ mod tests { let mut encoder = MockEncoder::default(); - let result = encode_value(&mut encoder, &dict_arr, 2, &Type::TEXT, FieldFormat::Text); + let arrow_field = Field::new("x", DataType::Utf8, true); + let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text); + + let result = encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field); assert!(result.is_ok()); diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index a13c1c7..ae9893b 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -386,7 +386,7 @@ pub(crate) fn encode_list( } } }, - DataType::Struct(_) => { + DataType::Struct(arrow_fields) => { let fields = match type_.kind() { postgres_types::Kind::Array(struct_type_) => Ok(struct_type_), _ => Err(format!( @@ -406,7 +406,7 @@ pub(crate) fn encode_list( .map_err(ToSqlError::from)?; let values: PgWireResult> = (0..arr.len()) - .map(|row| encode_struct(&arr, row, fields, format)) + .map(|row| encode_struct(&arr, row, arrow_fields, fields, format)) .map(|x| { if matches!(format, FieldFormat::Text) { x.map(|opt| { diff --git a/arrow-pg/src/row_encoder.rs b/arrow-pg/src/row_encoder.rs index 145c9ab..c8a73f6 100644 --- a/arrow-pg/src/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -33,13 +33,14 @@ impl RowEncoder { if self.curr_idx == self.rb.num_rows() { return None; } + let arrow_schema = self.rb.schema_ref(); let mut encoder = DataRowEncoder::new(self.fields.clone()); for col in 0..self.rb.num_columns() { let array = self.rb.column(col); - let field = &self.fields[col]; - let type_ = field.datatype(); - let format = field.format(); - encode_value(&mut encoder, array, self.curr_idx, type_, format).unwrap(); + let arrow_field = arrow_schema.field(col); + let pg_field = &self.fields[col]; + + encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field).unwrap(); } self.curr_idx += 1; Some(encoder.finish()) diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 49fce1b..1224797 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -2,11 +2,12 @@ use std::sync::Arc; #[cfg(not(feature = "datafusion"))] use arrow::array::{Array, StructArray}; +use arrow_schema::Fields; #[cfg(feature = "datafusion")] use datafusion::arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; -use pgwire::api::results::FieldFormat; +use pgwire::api::results::{FieldFormat, FieldInfo}; use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; use postgres_types::{Field, IsNull, ToSql, Type}; @@ -16,6 +17,7 @@ use crate::encoder::{encode_value, EncodedValue, Encoder}; pub(crate) fn encode_struct( arr: &Arc, idx: usize, + arrow_fields: &Fields, fields: &[Field], format: FieldFormat, ) -> PgWireResult> { @@ -27,7 +29,11 @@ pub(crate) fn encode_struct( for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_(); - encode_value(&mut row_encoder, arr, idx, type_, format).unwrap(); + + let arrow_field = &arrow_fields[i]; + let pgwire_field = FieldInfo::new("fields".to_string(), None, None, type_.clone(), format); + + encode_value(&mut row_encoder, arr, idx, arrow_field, &pgwire_field).unwrap(); } Ok(Some(EncodedValue { bytes: row_encoder.row_buffer, From c19dc3d768054c42c485faecb512f4131c705241 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 6 Nov 2025 13:35:20 +0800 Subject: [PATCH 3/5] 
feat: update api after merge --- arrow-pg/src/encoder.rs | 12 +----------- arrow-pg/src/list_encoder.rs | 18 ------------------ arrow-pg/src/struct_encoder.rs | 21 ++++++++++++++++----- 3 files changed, 17 insertions(+), 34 deletions(-) diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index 6cf64ff..b7144a0 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -10,7 +10,7 @@ use bytes::BytesMut; use chrono::{NaiveDate, NaiveDateTime}; #[cfg(feature = "datafusion")] use datafusion::arrow::{array::*, datatypes::*}; -use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo}; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::types::format::FormatOptions; use pgwire::types::ToSqlText; @@ -287,8 +287,6 @@ pub fn encode_value( _arrow_filed: &Field, pg_field: &FieldInfo, ) -> PgWireResult<()> { - let type_ = pg_field.datatype(); - match arr.data_type() { DataType::Null => encoder.encode_field(&None::, pg_field)?, DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?, @@ -424,14 +422,6 @@ pub fn encode_value( encoder.encode_field(&value, pg_field)? } DataType::Struct(arrow_fields) => { - let fields = match type_.kind() { - postgres_types::Kind::Composite(fields) => fields, - _ => { - return Err(PgWireError::ApiError(ToSqlError::from(format!( - "Failed to unwrap a composite type from type {type_}" - )))); - } - }; let value = encode_struct(arr, idx, arrow_fields, pg_field)?; encoder.encode_field(&value, pg_field)? } diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index 8c061cc..0da22fb 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -387,24 +387,6 @@ pub(crate) fn encode_list(arr: Arc, pg_field: &FieldInfo) -> PgWireRe } }, DataType::Struct(arrow_fields) => { - let fields = match type_.kind() { - postgres_types::Kind::Array(struct_type_) => Ok(struct_type_), - _ => Err(format!( - "Expected list type found type {} of kind {:?}", - type_, - type_.kind() - )), - } - .and_then(|struct_type| match struct_type.kind() { - postgres_types::Kind::Composite(fields) => Ok(fields), - _ => Err(format!( - "Failed to unwrap a composite type inside from type {} kind {:?}", - type_, - type_.kind() - )), - }) - .map_err(ToSqlError::from)?; - let values: PgWireResult> = (0..arr.len()) .map(|row| encode_struct(&arr, row, arrow_fields, pg_field)) .map(|x| { diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 869fd69..952362e 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -8,29 +8,41 @@ use datafusion::arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; use pgwire::api::results::{FieldFormat, FieldInfo}; -use pgwire::error::PgWireResult; +use pgwire::error::{PgWireError, PgWireResult}; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; -use postgres_types::{Field, IsNull, ToSql}; +use postgres_types::{IsNull, ToSql}; use crate::encoder::{encode_value, EncodedValue, Encoder}; +use crate::error::ToSqlError; pub(crate) fn encode_struct( arr: &Arc, idx: usize, arrow_fields: &Fields, - fields: &[Field], parent_pg_field_info: &FieldInfo, ) -> PgWireResult> { let arr = arr.as_any().downcast_ref::().unwrap(); if arr.is_null(idx) { return Ok(None); } - let mut row_encoder = StructEncoder::new(fields.len()); + + let fields = match parent_pg_field_info.datatype().kind() { + postgres_types::Kind::Composite(fields) => fields, + _ => { + return 
Err(PgWireError::ApiError(ToSqlError::from(format!( + "Failed to unwrap a composite type of {}", + parent_pg_field_info.datatype() + )))); + } + }; + + let mut row_encoder = StructEncoder::new(arrow_fields.len()); for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_(); let arrow_field = &arrow_fields[i]; + let mut pg_field = FieldInfo::new( field.name().to_string(), None, @@ -38,7 +50,6 @@ pub(crate) fn encode_struct( type_.clone(), parent_pg_field_info.format(), ); - pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone()); encode_value(&mut row_encoder, arr, idx, arrow_field, &pg_field).unwrap(); From d6508d9d95aa80fb190a0a7a25e2730cf06cacfb Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 6 Nov 2025 14:05:19 +0800 Subject: [PATCH 4/5] fix: duckdb example --- arrow-pg/examples/duckdb.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/arrow-pg/examples/duckdb.rs b/arrow-pg/examples/duckdb.rs index 29faa1e..7298680 100644 --- a/arrow-pg/examples/duckdb.rs +++ b/arrow-pg/examples/duckdb.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, Mutex}; use arrow_pg::datatypes::arrow_schema_to_pg_fields; use arrow_pg::datatypes::encode_recordbatch; use arrow_pg::datatypes::into_pg_type; +use arrow_schema::Field; use async_trait::async_trait; use duckdb::{params, Connection, Statement, ToSql}; use futures::stream; @@ -137,11 +138,13 @@ fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult Date: Fri, 7 Nov 2025 04:37:11 +0800 Subject: [PATCH 5/5] fix: resolve type for list/struct --- arrow-pg/src/row_encoder.rs | 5 ++++- arrow-pg/src/struct_encoder.rs | 20 ++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/arrow-pg/src/row_encoder.rs b/arrow-pg/src/row_encoder.rs index c8a73f6..f9b0a32 100644 --- a/arrow-pg/src/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -40,7 +40,10 @@ impl RowEncoder { let arrow_field = arrow_schema.field(col); let pg_field = &self.fields[col]; - encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field).unwrap(); + if let Err(e) = encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field) + { + return Some(Err(e)); + }; } self.curr_idx += 1; Some(encoder.finish()) diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 952362e..7db119f 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -8,12 +8,12 @@ use datafusion::arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; use pgwire::api::results::{FieldFormat, FieldInfo}; -use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; -use postgres_types::{IsNull, ToSql}; +use postgres_types::{Field, IsNull, ToSql}; +use crate::datatypes::into_pg_type; use crate::encoder::{encode_value, EncodedValue, Encoder}; -use crate::error::ToSqlError; pub(crate) fn encode_struct( arr: &Arc, @@ -26,17 +26,13 @@ pub(crate) fn encode_struct( return Ok(None); } - let fields = match parent_pg_field_info.datatype().kind() { - postgres_types::Kind::Composite(fields) => fields, - _ => { - return Err(PgWireError::ApiError(ToSqlError::from(format!( - "Failed to unwrap a composite type of {}", - parent_pg_field_info.datatype() - )))); - } - }; + let fields = arrow_fields + .iter() + .map(|f| into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t))) + .collect::>>()?; let mut row_encoder = StructEncoder::new(arrow_fields.len()); + 
for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_();
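
A minimal usage sketch of the field-based conversion introduced by this series
(not taken from the patches themselves): it assumes the arrow-schema dependency
already used in the duckdb example, and the helper and column names are
illustrative only.

    use std::sync::Arc;

    use arrow_pg::datatypes::into_pg_type;
    use arrow_schema::{DataType, Field};
    use pgwire::error::PgWireResult;
    use postgres_types::Type;

    // Callers now hand the whole Arrow Field (as an Arc<Field>) to
    // into_pg_type, so field-level metadata such as extension type names
    // stays available to the conversion.
    fn map_example_fields() -> PgWireResult<()> {
        // Scalar column: Utf8 maps to TEXT.
        let name = Arc::new(Field::new("name", DataType::Utf8, true));
        assert_eq!(into_pg_type(&name)?, Type::TEXT);

        // List column: the element field drives the array OID
        // (Int32 items map to INT4_ARRAY).
        let ids = Arc::new(Field::new(
            "ids",
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
            true,
        ));
        assert_eq!(into_pg_type(&ids)?, Type::INT4_ARRAY);

        Ok(())
    }

With the optional "geo" feature, a field whose extension type name matches
geoarrow_schema::PointType::NAME resolves to Type::POINT before the plain
DataType match runs, which is why the Field, rather than the bare DataType,
has to travel all the way down to the row, list, and struct encoders.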