From 933c951c49571d3953f36a69c95d14cb2e126710 Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Tue, 16 Sep 2025 02:30:53 +0200 Subject: [PATCH 01/10] Expand benchmark coverage and tests for arrow-avro writer to include additional types like Decimal, FixedSizeBinary, Utf8, List, Struct, and Map. Add round-trip validation for complex and logical types including Duration and UUID. --- arrow-avro/benches/avro_writer.rs | 456 +++++++++++++++++++++++++++++- arrow-avro/src/writer/encoder.rs | 300 ++++++++++++++++++++ arrow-avro/src/writer/mod.rs | 218 ++++++++++++++ 3 files changed, 967 insertions(+), 7 deletions(-) diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs index 924cbbdc84bd..aeb9edbac82a 100644 --- a/arrow-avro/benches/avro_writer.rs +++ b/arrow-avro/benches/avro_writer.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Benchmarks for `arrow‑avro` **Writer** (Avro Object Container Files) -//! +//! Benchmarks for `arrow-avro` Writer (Avro Object Container File) extern crate arrow_avro; extern crate criterion; extern crate once_cell; use arrow_array::{ - types::{Int32Type, Int64Type, TimestampMicrosecondType}, - ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, PrimitiveArray, RecordBatch, + builder::{ListBuilder, StringBuilder}, + types::{Int32Type, Int64Type, IntervalMonthDayNanoType, TimestampMicrosecondType}, + ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Decimal32Array, + Decimal64Array, FixedSizeBinaryArray, Float32Array, Float64Array, ListArray, PrimitiveArray, + RecordBatch, StringArray, StructArray, }; use arrow_avro::writer::AvroWriter; -use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use arrow_buffer::i256; +use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; use once_cell::sync::Lazy; use rand::{ @@ -35,6 +38,7 @@ use rand::{ rngs::StdRng, Rng, SeedableRng, }; +use std::collections::HashMap; use std::io::Cursor; use std::sync::Arc; use std::time::Duration; @@ -63,7 +67,9 @@ where #[inline] fn make_bool_array_with_tag(n: usize, tag: u64) -> BooleanArray { let mut rng = rng_for(tag, n); + // Can't use SampleUniform for bool; use the RNG's boolean helper let values = (0..n).map(|_| rng.random_bool(0.5)); + // This repo exposes `from_iter`, not `from_iter_values` for BooleanArray BooleanArray::from_iter(values.map(Some)) } @@ -81,6 +87,21 @@ fn make_i64_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { PrimitiveArray::::from_iter_values(values) } +#[inline] +fn rand_ascii_string(rng: &mut StdRng, min_len: usize, max_len: usize) -> String { + let len = rng.random_range(min_len..=max_len); + (0..len) + .map(|_| (rng.random_range(b'a'..=b'z') as char)) + .collect() +} + +#[inline] +fn make_utf8_array_with_tag(n: usize, tag: u64) -> StringArray { + let mut rng = rng_for(tag, n); + let data: Vec = (0..n).map(|_| rand_ascii_string(&mut rng, 3, 16)).collect(); + StringArray::from_iter_values(data) +} + #[inline] fn make_f32_array_with_tag(n: usize, tag: u64) -> Float32Array { let mut rng = rng_for(tag, n); @@ -98,14 +119,52 @@ fn make_f64_array_with_tag(n: usize, tag: u64) -> Float64Array { #[inline] fn make_binary_array_with_tag(n: usize, tag: u64) -> BinaryArray { let mut rng = rng_for(tag, n); - let mut payloads: Vec<[u8; 16]> = vec![[0; 16]; n]; - for p in payloads.iter_mut() { + let mut payloads: Vec> = Vec::with_capacity(n); 
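+        // Variable-length payloads (1..=16 bytes) exercise Avro's length-prefixed
+        // `bytes` encoding more realistically than the fixed 16-byte buffers used before.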
+ for _ in 0..n { + let len = rng.random_range(1..=16); + let mut p = vec![0u8; len]; rng.fill(&mut p[..]); + payloads.push(p); } let views: Vec<&[u8]> = payloads.iter().map(|p| &p[..]).collect(); + // This repo exposes a simple `from_vec` for BinaryArray BinaryArray::from_vec(views) } +#[inline] +fn make_fixed16_array_with_tag(n: usize, tag: u64) -> FixedSizeBinaryArray { + let mut rng = rng_for(tag, n); + let payloads = (0..n) + .map(|_| { + let mut b = [0u8; 16]; + rng.fill(&mut b); + b + }) + .collect::>(); + // Fixed-size constructor available in this repo + FixedSizeBinaryArray::try_from_iter(payloads.into_iter()).expect("build FixedSizeBinaryArray") +} + +/// Make an Arrow `Interval(IntervalUnit::MonthDayNano)` array with **non-negative** +/// (months, days, nanos) values, and nanos as **multiples of 1_000_000** (whole ms), +/// per Avro `duration` constraints used by the writer. +#[inline] +fn make_interval_mdn_array_with_tag( + n: usize, + tag: u64, +) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| { + let months: i32 = rng.random_range(0..=120); + let days: i32 = rng.random_range(0..=31); + // pick millis within a day (safe within u32::MAX and realistic) + let millis: u32 = rng.random_range(0..=86_400_000); + let nanos: i64 = (millis as i64) * 1_000_000; + IntervalMonthDayNanoType::make_value(months, days, nanos) + }); + PrimitiveArray::::from_iter_values(values) +} + #[inline] fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { let mut rng = rng_for(tag, n); @@ -115,6 +174,77 @@ fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray::from_iter_values(values) } +// === Decimal helpers & generators === + +#[inline] +fn pow10_i32(p: u8) -> i32 { + (0..p).fold(1i32, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn pow10_i64(p: u8) -> i64 { + (0..p).fold(1i64, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn pow10_i128(p: u8) -> i128 { + (0..p).fold(1i128, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn make_decimal32_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal32Array { + let mut rng = rng_for(tag, n); + let max = pow10_i32(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal32Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal32Array") +} + +#[inline] +fn make_decimal64_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal64Array { + let mut rng = rng_for(tag, n); + let max = pow10_i64(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal64Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal64Array") +} + +#[inline] +fn make_decimal128_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal128Array { + let mut rng = rng_for(tag, n); + let max = pow10_i128(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal128Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal128Array") +} + +#[inline] +fn make_decimal256_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal256Array { + // Generate within i128 range and widen to i256 to keep generation cheap and portable + let mut rng = rng_for(tag, n); + let max128 = pow10_i128(30).saturating_sub(1); + let values = (0..n).map(|_| { + let v: i128 
= rng.random_range(-max128..=max128); + i256::from_i128(v) + }); + Decimal256Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal256Array") +} + +#[inline] +fn make_fixed16_array(n: usize) -> FixedSizeBinaryArray { + make_fixed16_array_with_tag(n, 0xF15E_D016) +} + +#[inline] +fn make_interval_mdn_array(n: usize) -> PrimitiveArray { + make_interval_mdn_array_with_tag(n, 0xD0_1E_AD) +} + #[inline] fn make_bool_array(n: usize) -> BooleanArray { make_bool_array_with_tag(n, 0xB001) @@ -143,6 +273,57 @@ fn make_binary_array(n: usize) -> BinaryArray { fn make_ts_micros_array(n: usize) -> PrimitiveArray { make_ts_micros_array_with_tag(n, 0x7157_0001) } +#[inline] +fn make_utf8_array(n: usize) -> StringArray { + make_utf8_array_with_tag(n, 0x5712_07F8) +} +#[inline] +fn make_list_utf8_array(n: usize) -> ListArray { + make_list_utf8_array_with_tag(n, 0x0A11_57ED) +} +#[inline] +fn make_struct_array(n: usize) -> StructArray { + make_struct_array_with_tag(n, 0x57_AB_C7) +} + +#[inline] +fn make_list_utf8_array_with_tag(n: usize, tag: u64) -> ListArray { + let mut rng = rng_for(tag, n); + let mut builder = ListBuilder::new(StringBuilder::new()); + for _ in 0..n { + let items = rng.random_range(0..=5); + for _ in 0..items { + let s = rand_ascii_string(&mut rng, 1, 12); + builder.values().append_value(s.as_str()); + } + builder.append(true); + } + builder.finish() +} + +#[inline] +fn make_struct_array_with_tag(n: usize, tag: u64) -> StructArray { + let s_tag = tag ^ 0x5u64; + let i_tag = tag ^ 0x6u64; + let f_tag = tag ^ 0x7u64; + let s_col: ArrayRef = Arc::new(make_utf8_array_with_tag(n, s_tag)); + let i_col: ArrayRef = Arc::new(make_i32_array_with_tag(n, i_tag)); + let f_col: ArrayRef = Arc::new(make_f64_array_with_tag(n, f_tag)); + StructArray::from(vec![ + ( + Arc::new(Field::new("s1", DataType::Utf8, false)), + s_col.clone(), + ), + ( + Arc::new(Field::new("s2", DataType::Int32, false)), + i_col.clone(), + ), + ( + Arc::new(Field::new("s3", DataType::Float64, false)), + f_col.clone(), + ), + ]) +} #[inline] fn schema_single(name: &str, dt: DataType) -> Arc { @@ -159,6 +340,36 @@ fn schema_mixed() -> Arc { ])) } +#[inline] +fn schema_fixed16() -> Arc { + schema_single("field1", DataType::FixedSizeBinary(16)) +} + +#[inline] +fn schema_uuid16() -> Arc { + let mut md = HashMap::new(); + md.insert("logicalType".to_string(), "uuid".to_string()); + let field = Field::new("uuid", DataType::FixedSizeBinary(16), false).with_metadata(md); + Arc::new(Schema::new(vec![field])) +} + +#[inline] +fn schema_interval_mdn() -> Arc { + schema_single("duration", DataType::Interval(IntervalUnit::MonthDayNano)) +} + +#[inline] +fn schema_decimal_with_size(name: &str, dt: DataType, size_meta: Option) -> Arc { + let field = if let Some(size) = size_meta { + let mut md = HashMap::new(); + md.insert("size".to_string(), size.to_string()); + Field::new(name, dt, false).with_metadata(md) + } else { + Field::new(name, dt, false) + }; + Arc::new(Schema::new(vec![field])) +} + static BOOLEAN_DATA: Lazy> = Lazy::new(|| { let schema = schema_single("field1", DataType::Boolean); SIZES @@ -225,6 +436,40 @@ static BINARY_DATA: Lazy> = Lazy::new(|| { .collect() }); +static FIXED16_DATA: Lazy> = Lazy::new(|| { + let schema = schema_fixed16(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_fixed16_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static UUID16_DATA: Lazy> = Lazy::new(|| { + let schema = 
schema_uuid16(); + SIZES + .iter() + .map(|&n| { + // Same values as Fixed16; writer path differs because of field metadata + let col: ArrayRef = Arc::new(make_fixed16_array_with_tag(n, 0x7575_6964_7575_6964)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INTERVAL_MDN_DATA: Lazy> = Lazy::new(|| { + let schema = schema_interval_mdn(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_interval_mdn_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + static TIMESTAMP_US_DATA: Lazy> = Lazy::new(|| { let schema = schema_single("field1", DataType::Timestamp(TimeUnit::Microsecond, None)); SIZES @@ -250,6 +495,190 @@ static MIXED_DATA: Lazy> = Lazy::new(|| { .collect() }); +static UTF8_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Utf8); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static LIST_UTF8_DATA: Lazy> = Lazy::new(|| { + // IMPORTANT: ListBuilder creates a child field named "item" that is nullable by default. + // Make the schema's list item nullable to match the array we construct. + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = schema_single("field1", DataType::List(item_field)); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_list_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static STRUCT_DATA: Lazy> = Lazy::new(|| { + let struct_dt = DataType::Struct( + vec![ + Field::new("s1", DataType::Utf8, false), + Field::new("s2", DataType::Int32, false), + Field::new("s3", DataType::Float64, false), + ] + .into(), + ); + let schema = schema_single("field1", struct_dt); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_struct_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL32_DATA: Lazy> = Lazy::new(|| { + // Choose a representative precision/scale within Decimal32 limits + let precision: u8 = 7; + let scale: i8 = 2; + let schema = schema_single("amount", DataType::Decimal32(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal32_array_with_tag(n, 0xDEC_0032, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL64_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 13; + let scale: i8 = 3; + let schema = schema_single("amount", DataType::Decimal64(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal64_array_with_tag(n, 0xDEC_0064, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_BYTES_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 25; + let scale: i8 = 6; + let schema = schema_single("amount", DataType::Decimal128(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_0128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_FIXED16_DATA: Lazy> = Lazy::new(|| { + // Same logical type as above but force Avro fixed(16) via metadata "size": "16" + let precision: u8 = 25; + let scale: i8 = 6; + let schema = + 
schema_decimal_with_size("amount", DataType::Decimal128(precision, scale), Some(16)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_F128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL256_DATA: Lazy> = Lazy::new(|| { + // Use a higher precision typical of 256-bit decimals + let precision: u8 = 50; + let scale: i8 = 10; + let schema = schema_single("amount", DataType::Decimal256(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal256_array_with_tag(n, 0xDEC_0256, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static MAP_DATA: Lazy> = Lazy::new(|| { + use arrow_array::builder::{MapBuilder, StringBuilder}; + + let key_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("values", DataType::Utf8, true)); + let entry_struct = Field::new( + "entries", + DataType::Struct(vec![key_field.as_ref().clone(), value_field.as_ref().clone()].into()), + false, + ); + let map_dt = DataType::Map(Arc::new(entry_struct), false); + let schema = schema_single("field1", map_dt); + + SIZES + .iter() + .map(|&n| { + // Build a MapArray with n rows + let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new()); + let mut rng = rng_for(0x00D0_0D1A, n); + for _ in 0..n { + let entries = rng.random_range(0..=5); + for _ in 0..entries { + let k = rand_ascii_string(&mut rng, 3, 10); + let v = rand_ascii_string(&mut rng, 0, 12); + // keys non-nullable, values nullable allowed but we provide non-null here + builder.keys().append_value(k); + builder.values().append_value(v); + } + builder.append(true).expect("Error building MapArray"); + } + let col: ArrayRef = Arc::new(builder.finish()); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static ENUM_DATA: Lazy> = Lazy::new(|| { + // To represent an Avro enum, the Arrow writer expects a Dictionary + // field with metadata specifying the enum symbols. 
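+    // The symbols below must match the dictionary values exactly and in the same
+    // order; the encoder validates this correspondence when building its plan.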
+ let enum_symbols = r#"["RED", "GREEN", "BLUE"]"#; + let mut metadata = HashMap::new(); + metadata.insert("avro.enum.symbols".to_string(), enum_symbols.to_string()); + + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field = Field::new("color_enum", dict_type, false).with_metadata(metadata); + let schema = Arc::new(Schema::new(vec![field])); + + let dict_values: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])); + + SIZES + .iter() + .map(|&n| { + use arrow_array::DictionaryArray; + let mut rng = rng_for(0x3A7A, n); + let keys_vec: Vec = (0..n).map(|_| rng.random_range(0..=2)).collect(); + let keys = PrimitiveArray::::from(keys_vec); + + let dict_array = + DictionaryArray::::try_new(keys, dict_values.clone()).unwrap(); + let col: ArrayRef = Arc::new(dict_array); + + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + fn ocf_size_for_batch(batch: &RecordBatch) -> usize { let schema_owned: Schema = (*batch.schema()).clone(); let cursor = Cursor::new(Vec::::with_capacity(1024)); @@ -314,6 +743,19 @@ fn criterion_benches(c: &mut Criterion) { bench_writer_scenario(c, "write-Binary(Bytes)", &BINARY_DATA); bench_writer_scenario(c, "write-TimestampMicros", &TIMESTAMP_US_DATA); bench_writer_scenario(c, "write-Mixed", &MIXED_DATA); + bench_writer_scenario(c, "write-Utf8", &UTF8_DATA); + bench_writer_scenario(c, "write-List", &LIST_UTF8_DATA); + bench_writer_scenario(c, "write-Struct", &STRUCT_DATA); + bench_writer_scenario(c, "write-FixedSizeBinary16", &FIXED16_DATA); + bench_writer_scenario(c, "write-UUID(logicalType)", &UUID16_DATA); + bench_writer_scenario(c, "write-IntervalMonthDayNanoDuration", &INTERVAL_MDN_DATA); + bench_writer_scenario(c, "write-Decimal32(bytes)", &DECIMAL32_DATA); + bench_writer_scenario(c, "write-Decimal64(bytes)", &DECIMAL64_DATA); + bench_writer_scenario(c, "write-Decimal128(bytes)", &DECIMAL128_BYTES_DATA); + bench_writer_scenario(c, "write-Decimal128(fixed16)", &DECIMAL128_FIXED16_DATA); + bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA); + bench_writer_scenario(c, "write-Map", &MAP_DATA); + bench_writer_scenario(c, "write-Enum", &ENUM_DATA); } criterion_group! 
{ diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs index d80a3e739a63..39c70c868821 100644 --- a/arrow-avro/src/writer/encoder.rs +++ b/arrow-avro/src/writer/encoder.rs @@ -363,6 +363,60 @@ impl<'a> FieldEncoder<'a> { .ok_or_else(|| ArrowError::SchemaError("Expected FixedSizeBinaryArray".into()))?; Encoder::Uuid(UuidEncoder(arr)) } + FieldPlan::Map { values_nullability, + value_plan } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected MapArray".into()))?; + Encoder::Map(Box::new(MapEncoder::try_new(arr, *values_nullability, value_plan.as_ref())?)) + } + FieldPlan::Enum { symbols} => match array.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 || **value_dt != DataType::Utf8 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + let dict = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::SchemaError("Expected DictionaryArray".into()) + })?; + + let values = dict + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Dictionary values must be Utf8".into()) + })?; + if values.len() != symbols.len() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol length {} != dictionary size {}", + symbols.len(), + values.len() + ))); + } + for i in 0..values.len() { + if values.value(i) != symbols[i].as_str() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol mismatch at {i}: schema='{}' dict='{}'", + symbols[i], + values.value(i) + ))); + } + } + let keys = dict.keys(); + Encoder::Enum(EnumEncoder { keys }) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro enum site requires DataType::Dictionary, found: {other:?}" + ))) + } + } other => { return Err(ArrowError::NotYetImplemented(format!( "Avro writer: {other:?} not yet supported", @@ -443,6 +497,14 @@ enum FieldPlan { Decimal { size: Option }, /// Avro UUID logical type (fixed) Uuid, + /// Avro map with value‑site nullability and nested plan + Map { + values_nullability: Option, + value_plan: Box, + }, + /// Avro enum; maps to Arrow Dictionary with dictionary values + /// exactly equal and ordered as the Avro enum `symbols`. 
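+    /// The dictionary keys are written directly as the zero-based enum index.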
+ Enum { symbols: Arc<[String]> }, } #[derive(Debug, Clone)] @@ -631,6 +693,54 @@ impl FieldPlan { "Avro array maps to Arrow List/LargeList, found: {other:?}" ))), }, + Codec::Map(values_dt) => { + let entries_field = match arrow_field.data_type() { + DataType::Map(entries, _sorted) => entries.as_ref(), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map maps to Arrow DataType::Map, found: {other:?}" + ))) + } + }; + let entries_struct_fields = match entries_field.data_type() { + DataType::Struct(fs) => fs, + other => { + return Err(ArrowError::SchemaError(format!( + "Arrow Map entries must be Struct, found: {other:?}" + ))) + } + }; + let value_idx = + find_map_value_field_index(entries_struct_fields).ok_or_else(|| { + ArrowError::SchemaError("Map entries struct missing value field".into()) + })?; + let value_field = entries_struct_fields[value_idx].as_ref(); + let value_plan = FieldPlan::build(values_dt.as_ref(), value_field)?; + Ok(FieldPlan::Map { + values_nullability: values_dt.nullability(), + value_plan: Box::new(value_plan), + }) + } + Codec::Enum(symbols) => match arrow_field.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + if **value_dt != DataType::Utf8 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + Ok(FieldPlan::Enum { + symbols: symbols.clone(), + }) + } + other => Err(ArrowError::SchemaError(format!( + "Avro enum maps to Arrow Dictionary, found: {other:?}" + ))), + }, // decimal site (bytes or fixed(N)) with precision/scale validation Codec::Decimal(precision, scale_opt, fixed_size_opt) => { let (ap, as_) = match arrow_field.data_type() { @@ -700,6 +810,9 @@ enum Encoder<'a> { Decimal64(Decimal64Encoder<'a>), Decimal128(Decimal128Encoder<'a>), Decimal256(Decimal256Encoder<'a>), + /// Avro `enum` encoder: writes the key (int) as the enum index. + Enum(EnumEncoder<'a>), + Map(Box>), } impl<'a> Encoder<'a> { @@ -730,6 +843,8 @@ impl<'a> Encoder<'a> { Encoder::Decimal64(e) => (e).encode(out, idx), Encoder::Decimal128(e) => (e).encode(out, idx), Encoder::Decimal256(e) => (e).encode(out, idx), + Encoder::Map(e) => (e).encode(out, idx), + Encoder::Enum(e) => (e).encode(out, idx), } } } @@ -795,6 +910,130 @@ impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> { type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>; type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>; + +/// Internal key array kind used by Map encoder. 
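+/// Avro map keys are always strings, but the Arrow side may store them as either
+/// `Utf8` or `LargeUtf8`, so both offset widths are handled.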
+enum KeyKind<'a> { + Utf8(&'a GenericStringArray), + LargeUtf8(&'a GenericStringArray), +} +struct MapEncoder<'a> { + map: &'a MapArray, + keys: KeyKind<'a>, + values: FieldEncoder<'a>, + keys_offset: usize, + values_offset: usize, +} + +fn encode_map_entries( + out: &mut W, + keys: &GenericStringArray, + keys_offset: usize, + start: usize, + end: usize, + mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>, +) -> Result<(), ArrowError> +where + W: Write + ?Sized, + O: OffsetSizeTrait, +{ + encode_blocked_range(out, start, end, |out, j| { + let j_key = j.saturating_sub(keys_offset); + write_len_prefixed(out, keys.value(j_key).as_bytes())?; + write_item(out, j) + }) +} + +impl<'a> MapEncoder<'a> { + fn try_new( + map: &'a MapArray, + values_nullability: Option, + value_plan: &FieldPlan, + ) -> Result { + let keys_arr = map.keys(); + let keys_kind = match keys_arr.data_type() { + DataType::Utf8 => KeyKind::Utf8(keys_arr.as_string::()), + DataType::LargeUtf8 => KeyKind::LargeUtf8(keys_arr.as_string::()), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map requires string keys; Arrow key type must be Utf8/LargeUtf8, found: {other:?}" + ))) + } + }; + + let entries_struct_fields = match map.data_type() { + DataType::Map(entries, _) => match entries.data_type() { + DataType::Struct(fs) => fs, + other => { + return Err(ArrowError::SchemaError(format!( + "Arrow Map entries must be Struct, found: {other:?}" + ))) + } + }, + _ => { + return Err(ArrowError::SchemaError( + "Expected MapArray with DataType::Map".into(), + )) + } + }; + + let v_idx = find_map_value_field_index(entries_struct_fields).ok_or_else(|| { + ArrowError::SchemaError("Map entries struct missing value field".into()) + })?; + let value_field = entries_struct_fields[v_idx].as_ref(); + + let values_enc = prepare_value_site_encoder( + map.values().as_ref(), + value_field, + values_nullability, + value_plan, + )?; + + Ok(Self { + map, + keys: keys_kind, + values: values_enc, + keys_offset: keys_arr.offset(), + values_offset: map.values().offset(), + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.map.offsets(); + let start = offsets[idx] as usize; + let end = offsets[idx + 1] as usize; + + let mut write_item = |out: &mut W, j: usize| { + let j_val = j.saturating_sub(self.values_offset); + self.values.encode(out, j_val) + }; + + match self.keys { + KeyKind::Utf8(arr) => { + encode_map_entries(out, arr, self.keys_offset, start, end, write_item) + } + KeyKind::LargeUtf8(arr) => { + encode_map_entries(out, arr, self.keys_offset, start, end, write_item) + } + } + } +} + +/// Avro `enum` encoder for Arrow `DictionaryArray`. +/// +/// Per Avro spec, an enum is encoded as an **int** equal to the +/// zero-based position of the symbol in the schema’s `symbols` list. +/// We validate at construction that the dictionary values equal the symbols, +/// so we can directly write the key value here. 
+struct EnumEncoder<'a> { + keys: &'a PrimitiveArray, +} +impl EnumEncoder<'_> { + fn encode(&mut self, out: &mut W, row: usize) -> Result<(), ArrowError> { + let idx = self.keys.value(row); + write_int(out, idx) + } +} + struct StructEncoder<'a> { encoders: Vec>, } @@ -1314,6 +1553,25 @@ mod tests { assert_bytes_eq(&got, &expected); } + #[test] + fn enum_encoder_dictionary() { + // symbols: ["A","B","C"], keys [2,0,1] + let dict_values = StringArray::from(vec!["A", "B", "C"]); + let keys = Int32Array::from(vec![2, 0, 1]); + let dict = + DictionaryArray::::try_new(keys, Arc::new(dict_values) as ArrayRef).unwrap(); + let symbols = Arc::<[String]>::from( + vec!["A".to_string(), "B".to_string(), "C".to_string()].into_boxed_slice(), + ); + let plan = FieldPlan::Enum { symbols }; + let got = encode_all(&dict, &plan, None); + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(1)); + assert_bytes_eq(&got, &expected); + } + #[test] fn decimal_bytes_and_fixed() { // Use Decimal128 with small positives and negatives @@ -1498,6 +1756,48 @@ mod tests { } } + #[test] + fn map_encoder_string_keys_int_values() { + // Build MapArray with two rows + // Row0: {"k1":1, "k2":2} + // Row1: {} + let keys = StringArray::from(vec!["k1", "k2"]); + let values = Int32Array::from(vec![1, 2]); + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + let entries = StructArray::new( + entries_fields, + vec![Arc::new(keys) as ArrayRef, Arc::new(values) as ArrayRef], + None, + ); + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 2, 2].into()); + let map = MapArray::new( + Field::new("entries", entries.data_type().clone(), false).into(), + offsets, + entries, + None, + false, + ); + let plan = FieldPlan::Map { + values_nullability: None, + value_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&map, &plan, None); + let mut expected = Vec::new(); + // Row0: block 2 then pairs + expected.extend(avro_long_bytes(2)); + expected.extend(avro_len_prefixed_bytes(b"k1")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_len_prefixed_bytes(b"k2")); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + // Row1: empty + expected.extend(avro_long_bytes(0)); + assert_bytes_eq(&got, &expected); + } + #[test] fn list64_encoder_int32() { // LargeList [[1,2,3], []] diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs index a5b2691bb816..f5e84eeb50bb 100644 --- a/arrow-avro/src/writer/mod.rs +++ b/arrow-avro/src/writer/mod.rs @@ -415,4 +415,222 @@ mod tests { ); Ok(()) } + + #[test] + fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/simple_fixed.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build avro reader"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + drop(writer); + let rt_file = File::open(tmp.path()).expect("open 
round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, original); + Ok(()) + } + + #[cfg(not(feature = "canonical_extension_types"))] + #[test] + fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> { + let in_file = + File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for duration_uuid.avro"); + let in_schema = reader.schema(); + let has_mdn = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Interval(IntervalUnit::MonthDayNano) + ) + }); + assert!( + has_mdn, + "expected at least one Interval(MonthDayNano) field in duration_uuid.avro" + ); + let has_uuid_fixed = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16))); + assert!( + has_uuid_fixed, + "expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro" + ); + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + { + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + } + let rt_file = File::open(tmp.path()).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, input); + Ok(()) + } + + // This test reads the same 'nonnullable.impala.avro' used by the reader tests, + // writes it back out with the writer (hitting Map encoding paths), then reads it + // again and asserts exact Arrow equivalence. 
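+    // Full `RecordBatch` equality means both the schema and every value must
+    // survive the writer's Map encoding path.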
+ #[test] + fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> { + // Load source Avro with Map fields + let path = arrow_test_data("avro/nonnullable.impala.avro"); + let rdr_file = File::open(&path).expect("open avro/nonnullable.impala.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nonnullable.impala.avro"); + // Collect all input batches and concatenate to a single RecordBatch + let in_schema = reader.schema(); + // Sanity: ensure the file actually contains at least one Map field + let has_map = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Map(_, _))); + assert!( + has_map, + "expected at least one Map field in avro/nonnullable.impala.avro" + ); + + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Write out using the OCF writer into an in-memory Vec + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let out_bytes = writer.into_inner(); + // Read the produced bytes back with the Reader + let mut rt_reader = ReaderBuilder::new() + .build(Cursor::new(out_bytes)) + .expect("build reader for round-tripped in-memory OCF"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + // Exact value fidelity (schema + data) + assert_eq!( + roundtrip, original, + "Round-trip Avro map data mismatch for nonnullable.impala.avro" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> { + // (file, resolve via ARROW_TEST_DATA?) 
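+        // `true`: resolve via `arrow_test_data` (ARROW_TEST_DATA); `false`: resolve
+        // relative to this crate's CARGO_MANIFEST_DIR.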
+ let files: [(&str, bool); 8] = [ + ("avro/fixed_length_decimal.avro", true), // fixed-backed -> Decimal128(25,2) + ("avro/fixed_length_decimal_legacy.avro", true), // legacy fixed[8] -> Decimal64(13,2) + ("avro/int32_decimal.avro", true), // bytes-backed -> Decimal32(4,2) + ("avro/int64_decimal.avro", true), // bytes-backed -> Decimal64(10,2) + ("test/data/int256_decimal.avro", false), // bytes-backed -> Decimal256(76,2) + ("test/data/fixed256_decimal.avro", false), // fixed[32]-backed -> Decimal256(76,10) + ("test/data/fixed_length_decimal_legacy_32.avro", false), // legacy fixed[4] -> Decimal32(9,2) + ("test/data/int128_decimal.avro", false), // bytes-backed -> Decimal128(38,2) + ]; + for (rel, in_test_data_dir) in files { + // Resolve path the same way as reader::test_decimal + let path: String = if in_test_data_dir { + arrow_test_data(rel) + } else { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(rel) + .to_string_lossy() + .into_owned() + }; + // Read original file into a single RecordBatch for comparison + let f_in = File::open(&path).expect("open input avro"); + let mut rdr = ReaderBuilder::new().build(BufReader::new(f_in))?; + let in_schema = rdr.schema(); + let in_batches = rdr.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + // Write it out with the OCF writer (no special compression) + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + let out_file = File::create(&out_path).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + // Read back the file we just wrote and compare equality (schema + data) + let f_rt = File::open(&out_path).expect("open roundtrip avro"); + let mut rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?; + let rt_schema = rt_rdr.schema(); + let rt_batches = rt_rdr.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt"); + assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}"); + } + Ok(()) + } + + #[test] + fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> { + // Read the known-good enum file (same as reader::test_simple) + let path = arrow_test_data("avro/simple_enum.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_enum.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for simple_enum.avro"); + // Concatenate all batches to one RecordBatch for a clean equality check + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Sanity: expect at least one Dictionary(Int32, Utf8) column (enum) + let has_enum_dict = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8 + ) + }); + assert!( + has_enum_dict, + "Expected at least one enum-mapped Dictionary field" + ); + // Write with OCF writer into memory using the reader-provided Arrow schema. + // The writer will embed the Avro JSON from `avro.schema` metadata if present. 
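+        // Embedding the reader-provided schema keeps the enum symbol order
+        // identical across the round trip.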
+ let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + // Read back and compare for exact equality (schema + data) + let mut rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("reader for round-trip"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!(roundtrip, original, "Avro enum round-trip mismatch"); + Ok(()) + } } From d329aef6eedb02c98f7d33118ce0c2b7ce2c8fd6 Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Tue, 16 Sep 2025 14:24:36 +0200 Subject: [PATCH 02/10] Eliminate unnecessary variable assignment Co-authored-by: Connor Sanders --- arrow-avro/src/writer/encoder.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs index 39c70c868821..fea37bb5b014 100644 --- a/arrow-avro/src/writer/encoder.rs +++ b/arrow-avro/src/writer/encoder.rs @@ -1029,8 +1029,7 @@ struct EnumEncoder<'a> { } impl EnumEncoder<'_> { fn encode(&mut self, out: &mut W, row: usize) -> Result<(), ArrowError> { - let idx = self.keys.value(row); - write_int(out, idx) + write_int(out, self.keys.value(row)) } } From 47512dcd4f94a024e324ae3ce1a2a47f4ce84412 Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Tue, 16 Sep 2025 14:42:27 +0200 Subject: [PATCH 03/10] Move `encode_map_entries` into `MapEncoder` implementation to simplify code structure. --- arrow-avro/src/writer/encoder.rs | 60 +++++++++++++++++++------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs index fea37bb5b014..fd619249617e 100644 --- a/arrow-avro/src/writer/encoder.rs +++ b/arrow-avro/src/writer/encoder.rs @@ -924,25 +924,6 @@ struct MapEncoder<'a> { values_offset: usize, } -fn encode_map_entries( - out: &mut W, - keys: &GenericStringArray, - keys_offset: usize, - start: usize, - end: usize, - mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>, -) -> Result<(), ArrowError> -where - W: Write + ?Sized, - O: OffsetSizeTrait, -{ - encode_blocked_range(out, start, end, |out, j| { - let j_key = j.saturating_sub(keys_offset); - write_len_prefixed(out, keys.value(j_key).as_bytes())?; - write_item(out, j) - }) -} - impl<'a> MapEncoder<'a> { fn try_new( map: &'a MapArray, @@ -997,6 +978,25 @@ impl<'a> MapEncoder<'a> { }) } + fn encode_map_entries( + out: &mut W, + keys: &GenericStringArray, + keys_offset: usize, + start: usize, + end: usize, + mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>, + ) -> Result<(), ArrowError> + where + W: Write + ?Sized, + O: OffsetSizeTrait, + { + encode_blocked_range(out, start, end, |out, j| { + let j_key = j.saturating_sub(keys_offset); + write_len_prefixed(out, keys.value(j_key).as_bytes())?; + write_item(out, j) + }) + } + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { let offsets = self.map.offsets(); let start = offsets[idx] as usize; @@ -1008,12 +1008,22 @@ impl<'a> MapEncoder<'a> { }; match self.keys { - KeyKind::Utf8(arr) => { - encode_map_entries(out, arr, self.keys_offset, start, end, write_item) - } - KeyKind::LargeUtf8(arr) => { - encode_map_entries(out, arr, self.keys_offset, start, end, write_item) - } + KeyKind::Utf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + 
arr, + self.keys_offset, + start, + end, + write_item, + ), + KeyKind::LargeUtf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + arr, + self.keys_offset, + start, + end, + write_item, + ), } } } From 749ac2bf1957c7ef32669fda75e95b4854fe581f Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Wed, 17 Sep 2025 16:03:43 +0200 Subject: [PATCH 04/10] Add single-object encoding support and tests to Avro stream writer --- arrow-avro/src/writer/encoder.rs | 18 ++++- arrow-avro/src/writer/format.rs | 94 +++++++++++++++++++++--- arrow-avro/src/writer/mod.rs | 118 ++++++++++++++++++++++++++++++- arrow-schema/src/schema.rs | 2 +- 4 files changed, 217 insertions(+), 15 deletions(-) diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs index fd619249617e..6bf4ab1650e4 100644 --- a/arrow-avro/src/writer/encoder.rs +++ b/arrow-avro/src/writer/encoder.rs @@ -18,7 +18,7 @@ //! Avro Encoder for Arrow types. use crate::codec::{AvroDataType, AvroField, Codec}; -use crate::schema::Nullability; +use crate::schema::{Fingerprint, Nullability}; use arrow_array::cast::AsArray; use arrow_array::types::{ ArrowPrimitiveType, Float32Type, Float64Type, Int32Type, Int64Type, IntervalDayTimeType, @@ -33,6 +33,7 @@ use arrow_array::{ use arrow_array::{Decimal32Array, Decimal64Array}; use arrow_buffer::NullBuffer; use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit}; +use serde::Serialize; use std::io::Write; use std::sync::Arc; use uuid::Uuid; @@ -600,9 +601,22 @@ impl RecordEncoder { /// Encode a `RecordBatch` using this encoder plan. /// /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes. - pub fn encode(&self, out: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> { + pub fn encode( + &self, + out: &mut W, + batch: &RecordBatch, + prefix: Option<&[u8]>, + ) -> Result<(), ArrowError> { let mut column_encoders = self.prepare_for_batch(batch)?; for row in 0..batch.num_rows() { + if let Some(prefix) = prefix { + if !prefix.is_empty() { + out.write_all(prefix).map_err(|e| { + ArrowError::IoError(format!("write single-object prefix: {e}"), e) + })?; + } + } + for encoder in column_encoders.iter_mut() { encoder.encode(out, row)?; } diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs index 6fac9e8286a2..6d5ac7e1e0bf 100644 --- a/arrow-avro/src/writer/format.rs +++ b/arrow-avro/src/writer/format.rs @@ -16,7 +16,7 @@ // under the License. use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; -use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY}; +use crate::schema::{AvroSchema, Fingerprint, SCHEMA_METADATA_KEY, SINGLE_OBJECT_MAGIC}; use crate::writer::encoder::write_long; use arrow_schema::{ArrowError, Schema}; use rand::RngCore; @@ -26,6 +26,7 @@ use std::io::Write; /// Format abstraction implemented by each container‐level writer. pub trait AvroFormat: Debug + Default { /// Write any bytes required at the very beginning of the output stream + /// (file header, etc.). /// Implementations **must not** write any record data. fn start_stream( &mut self, @@ -36,6 +37,17 @@ pub trait AvroFormat: Debug + Default { /// Return the 16‑byte sync marker (OCF) or `None` (binary stream). fn sync_marker(&self) -> Option<&[u8; 16]>; + + /// Return the 10‑byte **Avro single‑object** prefix (`C3 01` magic + + /// little‑endian schema fingerprint) to be written **before each record**, + /// or `None` if the format does not use single‑object encoding. 
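+    /// (The prefix length varies by algorithm: 5 bytes for a Confluent ID, 18 for
+    /// MD5, and 34 for SHA-256; only the Rabin form is exactly 10 bytes.)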
+ /// + /// The default implementation returns `None`. `AvroBinaryFormat` overrides + /// this to return the appropriate single-object encoding prefix. + #[inline] + fn single_object_prefix(&self) -> Option<&[u8]> { + None + } } /// Avro Object Container File (OCF) format writer. @@ -53,10 +65,15 @@ impl AvroFormat for AvroOcfFormat { ) -> Result<(), ArrowError> { let mut rng = rand::rng(); rng.fill_bytes(&mut self.sync_marker); + // Choose the Avro schema JSON that the file will advertise. + // If `schema.metadata[SCHEMA_METADATA_KEY]` exists, AvroSchema::try_from + // uses it verbatim; otherwise it is generated from the Arrow schema. let avro_schema = AvroSchema::try_from(schema)?; + // Magic writer .write_all(b"Obj\x01") .map_err(|e| ArrowError::IoError(format!("write OCF magic: {e}"), e))?; + // File metadata map: { "avro.schema": , "avro.codec": } let codec_str = match compression { Some(CompressionCodec::Deflate) => "deflate", Some(CompressionCodec::Snappy) => "snappy", @@ -65,6 +82,7 @@ impl AvroFormat for AvroOcfFormat { Some(CompressionCodec::Xz) => "xz", None => "null", }; + // Map block: count=2, then key/value pairs, then terminating count=0 write_long(writer, 2)?; write_string(writer, SCHEMA_METADATA_KEY)?; write_bytes(writer, avro_schema.json_string.as_bytes())?; @@ -75,7 +93,6 @@ impl AvroFormat for AvroOcfFormat { writer .write_all(&self.sync_marker) .map_err(|e| ArrowError::IoError(format!("write OCF sync marker: {e}"), e))?; - Ok(()) } @@ -84,25 +101,84 @@ impl AvroFormat for AvroOcfFormat { } } -/// Raw Avro binary streaming format (no header or footer). +/// Raw Avro binary streaming format using **Single-Object Encoding** per record. +/// +/// Each record written by the stream writer is framed with a prefix determined +/// by the schema fingerprinting algorithm. +/// +/// See: +/// See: #[derive(Debug, Default)] -pub struct AvroBinaryFormat; +pub struct AvroBinaryFormat { + /// Pre-built, variable-length prefix written before each record. + prefix: Vec, +} impl AvroFormat for AvroBinaryFormat { fn start_stream( &mut self, _writer: &mut W, - _schema: &Schema, - _compression: Option, + schema: &Schema, + compression: Option, ) -> Result<(), ArrowError> { - Err(ArrowError::NotYetImplemented( - "avro binary format not yet implemented".to_string(), - )) + if compression.is_some() { + return Err(ArrowError::InvalidArgumentError( + "Compression not supported for Avro binary streaming (single-object encoding)" + .to_string(), + )); + } + + if let Some(id_str) = schema.metadata().get("confluent.schema.id") { + let id: u32 = id_str.parse().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Invalid Confluent schema ID in metadata: {id_str}" + )) + })?; + self.prefix.clear(); + self.prefix.push(0x00); + self.prefix.extend_from_slice(&id.to_be_bytes()); + return Ok(()); + } + + let avro_schema = AvroSchema::try_from(schema)?; + self.prefix.clear(); + + match avro_schema.fingerprint()? { + // Case 1: Confluent Schema Registry ID format. + // 1 magic byte (0x00) + 4-byte schema ID (Big Endian). + Fingerprint::Id(id) => { + self.prefix.push(0x00); + self.prefix.extend_from_slice(&id.to_be_bytes()); + } + + // Case 2: Standard single-object encoding with hash-based fingerprints. + // 2 magic bytes + N-byte fingerprint. 
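+                // Rabin fingerprints are serialized little-endian per the
+                // single-object spec; MD5/SHA-256 digests are appended as raw bytes.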
+ fp => { + self.prefix.extend_from_slice(&SINGLE_OBJECT_MAGIC); + match fp { + Fingerprint::Rabin(val) => self.prefix.extend_from_slice(&val.to_le_bytes()), + #[cfg(feature = "md5")] + Fingerprint::MD5(val) => self.prefix.extend_from_slice(val.as_ref()), + #[cfg(feature = "sha256")] + Fingerprint::SHA256(val) => self.prefix.extend_from_slice(val.as_ref()), + Fingerprint::Id(_) => unreachable!(), + } + } + } + Ok(()) } fn sync_marker(&self) -> Option<&[u8; 16]> { None } + + fn single_object_prefix(&self) -> Option<&[u8]> { + if self.prefix.is_empty() { + None + } else { + Some(&self.prefix) + } + } } #[inline] diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs index f5e84eeb50bb..636e8b34c39c 100644 --- a/arrow-avro/src/writer/mod.rs +++ b/arrow-avro/src/writer/mod.rs @@ -177,7 +177,8 @@ impl Writer { fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> { let mut buf = Vec::::with_capacity(1024); - self.encoder.encode(&mut buf, batch)?; + self.encoder + .encode(&mut buf, batch, self.format.single_object_prefix())?; let encoded = match self.compression { Some(codec) => codec.compress(&buf)?, None => buf, @@ -194,7 +195,9 @@ impl Writer { } fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { - self.encoder.encode(&mut self.writer, batch) + self.encoder + .encode(&mut self.writer, batch, self.format.single_object_prefix())?; + Ok(()) } } @@ -205,7 +208,7 @@ mod tests { use crate::reader::ReaderBuilder; use crate::schema::{AvroSchema, SchemaStore}; use crate::test_util::arrow_test_data; - use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch}; + use arrow_array::{ArrayRef, BinaryArray, Int32Array, Int64Array, RecordBatch}; use arrow_schema::{DataType, Field, IntervalUnit, Schema}; use std::fs::File; use std::io::{BufReader, Cursor}; @@ -230,6 +233,115 @@ mod tests { .expect("failed to build test RecordBatch") } + #[test] + fn test_stream_writer_writes_prefix_per_row() -> Result<(), ArrowError> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let avro_schema = AvroSchema::try_from(&schema)?; + + let fingerprint = avro_schema.fingerprint()?; + let mut expected_prefix = Vec::from(crate::schema::SINGLE_OBJECT_MAGIC); + match fingerprint { + crate::schema::Fingerprint::Rabin(val) => expected_prefix.extend(val.to_le_bytes()), + _ => panic!("Expected Rabin fingerprint for default stream writer"), + } + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef], + )?; + + let buffer: Vec = Vec::new(); + let mut writer = AvroStreamWriter::new(buffer, schema)?; + writer.write(&batch)?; + let actual_bytes = writer.into_inner(); + + let mut expected_bytes = Vec::new(); + // Row 1: prefix + zig-zag encoded(10) + expected_bytes.extend(&expected_prefix); + expected_bytes.push(0x14); + // Row 2: prefix + zig-zag encoded(20) + expected_bytes.extend(&expected_prefix); + expected_bytes.push(0x28); + + assert_eq!( + actual_bytes, expected_bytes, + "Stream writer output did not match expected prefix-per-row format" + ); + Ok(()) + } + + #[test] + fn test_stream_writer_with_id_fingerprint() -> Result<(), ArrowError> { + // 1. 
Schema with Confluent ID in metadata + let schema_id = 42u32; + let mut metadata = std::collections::HashMap::new(); + // Assume "confluent.schema.id" is the metadata key the implementation looks for + metadata.insert("confluent.schema.id".to_string(), schema_id.to_string()); + let schema = + Schema::new(vec![Field::new("value", DataType::Int64, false)]).with_metadata(metadata); + + // 2. Batch with two rows + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int64Array::from(vec![100, 200])) as ArrayRef], + )?; + + // 3. Write the batch using the stream writer + let buffer: Vec = Vec::new(); + let mut writer = AvroStreamWriter::new(buffer, schema)?; + writer.write(&batch)?; + let actual_bytes = writer.into_inner(); + + // 4. Construct expected output manually for Confluent wire format + let mut expected_bytes: Vec = Vec::new(); + let prefix = { + let mut p = vec![0x00]; // Confluent magic byte + p.extend(&schema_id.to_be_bytes()); + p + }; + + // Row 1: prefix + zig-zag encoded(100) -> 200 -> [0xC8, 0x01] + expected_bytes.extend(&prefix); + expected_bytes.extend(&[0xC8, 0x01]); + // Row 2: prefix + zig-zag encoded(200) -> 400 -> [0x90, 0x03] + expected_bytes.extend(&prefix); + expected_bytes.extend(&[0x90, 0x03]); + + // 5. Assert + assert_eq!( + actual_bytes, expected_bytes, + "Stream writer output for Confluent ID did not match expected format" + ); + Ok(()) + } + + #[test] + fn test_stream_writer_invalid_id_fingerprint_errors() { + let mut metadata = std::collections::HashMap::new(); + metadata.insert( + "confluent.schema.id".to_string(), + "not-a-valid-id".to_string(), + ); + let schema = + Schema::new(vec![Field::new("value", DataType::Int64, false)]).with_metadata(metadata); + + let buffer: Vec = Vec::new(); + let result = AvroStreamWriter::new(buffer, schema); + + let err = result.expect_err("Writer creation should fail for invalid schema ID"); + assert!( + matches!(err, ArrowError::InvalidArgumentError(_)), + "Expected InvalidArgumentError, but got {:?}", + err + ); + assert!( + err.to_string() + .contains("Invalid Confluent schema ID in metadata"), + "Error message did not match expectation. Got: {}", + err + ); + } + #[test] fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> { let batch = make_batch(); diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs index 04c01f18e1d8..1e4fefbc28eb 100644 --- a/arrow-schema/src/schema.rs +++ b/arrow-schema/src/schema.rs @@ -187,7 +187,7 @@ pub type SchemaRef = Arc; pub struct Schema { /// A sequence of fields that describe the schema. pub fields: Fields, - /// A map of key-value pairs containing additional meta data. + /// A map of key-value pairs containing additional metadata. pub metadata: HashMap, } From 3a0ff5434a3d5f777b63d95006636cce7a237953 Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Wed, 17 Sep 2025 21:24:29 +0200 Subject: [PATCH 05/10] Add support for configurable fingerprint strategies in Avro stream writer - Introduced `FingerprintStrategy` enum to customize fingerprinting methods, including Rabin, ConfluentSchemaId, MD5, and SHA256. - Updated stream writer to handle per-record prefix generation based on the selected strategy. - Added related unit tests for configurable fingerprint strategies. 
--- arrow-avro/src/schema.rs | 16 +++++++++ arrow-avro/src/writer/format.rs | 60 +++++++++++++++++++-------------- arrow-avro/src/writer/mod.rs | 58 ++++++++++--------------------- 3 files changed, 67 insertions(+), 67 deletions(-) diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 511ba280f7ae..51e98837b745 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -316,6 +316,22 @@ pub struct Fixed<'a> { pub attributes: Attributes<'a>, } +/// Defines the strategy for generating the per-record prefix for an Avro binary stream. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FingerprintStrategy { + /// Use the 64-bit Rabin fingerprint (default for single-object encoding). + #[default] + Rabin, + /// Use a Confluent Schema Registry 32-bit ID. + ConfluentSchemaId(u32), + #[cfg(feature = "md5")] + /// Use the 128-bit MD5 fingerprint. + MD5, + #[cfg(feature = "sha256")] + /// Use the 256-bit SHA-256 fingerprint. + SHA256, +} + /// A wrapper for an Avro schema in its JSON string representation. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AvroSchema { diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs index 6d5ac7e1e0bf..5feed13345d0 100644 --- a/arrow-avro/src/writer/format.rs +++ b/arrow-avro/src/writer/format.rs @@ -16,7 +16,10 @@ // under the License. use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; -use crate::schema::{AvroSchema, Fingerprint, SCHEMA_METADATA_KEY, SINGLE_OBJECT_MAGIC}; +use crate::schema::{ + AvroSchema, Fingerprint, FingerprintStrategy, CONFLUENT_MAGIC, SCHEMA_METADATA_KEY, + SINGLE_OBJECT_MAGIC, +}; use crate::writer::encoder::write_long; use arrow_schema::{ArrowError, Schema}; use rand::RngCore; @@ -33,6 +36,7 @@ pub trait AvroFormat: Debug + Default { writer: &mut W, schema: &Schema, compression: Option, + fingerprint_strategy: FingerprintStrategy, ) -> Result<(), ArrowError>; /// Return the 16‑byte sync marker (OCF) or `None` (binary stream). @@ -62,6 +66,7 @@ impl AvroFormat for AvroOcfFormat { writer: &mut W, schema: &Schema, compression: Option, + _fingerprint_strategy: FingerprintStrategy, ) -> Result<(), ArrowError> { let mut rng = rand::rng(); rng.fill_bytes(&mut self.sync_marker); @@ -120,48 +125,51 @@ impl AvroFormat for AvroBinaryFormat { _writer: &mut W, schema: &Schema, compression: Option, + fingerprint_strategy: FingerprintStrategy, ) -> Result<(), ArrowError> { if compression.is_some() { return Err(ArrowError::InvalidArgumentError( - "Compression not supported for Avro binary streaming (single-object encoding)" - .to_string(), + "Compression not supported for Avro binary streaming".to_string(), )); } - if let Some(id_str) = schema.metadata().get("confluent.schema.id") { - let id: u32 = id_str.parse().map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Invalid Confluent schema ID in metadata: {id_str}" - )) - })?; - self.prefix.clear(); - self.prefix.push(0x00); - self.prefix.extend_from_slice(&id.to_be_bytes()); - return Ok(()); - } - - let avro_schema = AvroSchema::try_from(schema)?; self.prefix.clear(); - match avro_schema.fingerprint()? { - // Case 1: Confluent Schema Registry ID format. - // 1 magic byte (0x00) + 4-byte schema ID (Big Endian). 
- Fingerprint::Id(id) => {
- self.prefix.push(0x00);
+ match fingerprint_strategy {
+ FingerprintStrategy::ConfluentSchemaId(id) => {
+ self.prefix.push(CONFLUENT_MAGIC[0]);
 self.prefix.extend_from_slice(&id.to_be_bytes());
 }
-
- // Case 2: Standard single-object encoding with hash-based fingerprints.
- // 2 magic bytes + N-byte fingerprint.
- fp => {
+ strategy => {
+ // All other strategies use the single-object encoding format
 self.prefix.extend_from_slice(&SINGLE_OBJECT_MAGIC);
+
+ let avro_schema = AvroSchema::try_from(schema)?;
+ let fp = match strategy {
+ FingerprintStrategy::Rabin => avro_schema.fingerprint()?,
+ #[cfg(feature = "md5")]
+ FingerprintStrategy::MD5 => AvroSchema::generate_fingerprint(
+ &avro_schema.schema()?,
+ crate::schema::FingerprintAlgorithm::MD5,
+ )?,
+ #[cfg(feature = "sha256")]
+ FingerprintStrategy::SHA256 => AvroSchema::generate_fingerprint(
+ &avro_schema.schema()?,
+ crate::schema::FingerprintAlgorithm::SHA256,
+ )?,
+ FingerprintStrategy::ConfluentSchemaId(_) => unreachable!(),
+ };
+
 match fp {
 Fingerprint::Rabin(val) => self.prefix.extend_from_slice(&val.to_le_bytes()),
 #[cfg(feature = "md5")]
 Fingerprint::MD5(val) => self.prefix.extend_from_slice(val.as_ref()),
 #[cfg(feature = "sha256")]
 Fingerprint::SHA256(val) => self.prefix.extend_from_slice(val.as_ref()),
- Fingerprint::Id(_) => unreachable!(),
+ Fingerprint::Id(_) => return Err(ArrowError::InvalidArgumentError(
+ "ConfluentSchemaId strategy cannot be used with a hash-based fingerprint."
+ .to_string(),
+ )),
 }
 }
 }
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index 636e8b34c39c..ad9bf3c9e6fe 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -34,7 +34,7 @@ pub mod format;

 use crate::codec::AvroFieldBuilder;
 use crate::compression::CompressionCodec;
-use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY};
+use crate::schema::{AvroSchema, FingerprintStrategy, SCHEMA_METADATA_KEY};
 use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder};
 use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
 use arrow_array::RecordBatch;
@@ -48,6 +48,7 @@ pub struct WriterBuilder {
 schema: Schema,
 codec: Option<CompressionCodec>,
 capacity: usize,
+ fingerprint_strategy: FingerprintStrategy,
 }

 impl WriterBuilder {
@@ -57,9 +58,17 @@ impl WriterBuilder {
 schema,
 codec: None,
 capacity: 1024,
+ fingerprint_strategy: FingerprintStrategy::default(),
 }
 }

+ /// Set the fingerprinting strategy for the stream writer.
+ /// This determines the per-record prefix format.
+ pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self {
+ self.fingerprint_strategy = strategy;
+ self
+ }
+
 /// Change the compression codec.
 pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
 self.codec = codec;
@@ -90,7 +99,7 @@ impl WriterBuilder {
 avro_schema.clone().json_string,
 );
 let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md));
- format.start_stream(&mut writer, &schema, self.codec)?;
+ format.start_stream(&mut writer, &schema, self.codec, self.fingerprint_strategy)?;
 let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?;
 let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref()).build()?;
 Ok(Writer {
@@ -206,7 +215,7 @@ mod tests {
 use super::*;
 use crate::compression::CompressionCodec;
 use crate::reader::ReaderBuilder;
- use crate::schema::{AvroSchema, SchemaStore};
+ use crate::schema::{AvroSchema, SchemaStore, CONFLUENT_MAGIC};
 use crate::test_util::arrow_test_data;
 use arrow_array::{ArrayRef, BinaryArray, Int32Array, Int64Array, RecordBatch};
 use arrow_schema::{DataType, Field, IntervalUnit, Schema};
@@ -272,30 +281,24 @@ mod tests {

 #[test]
 fn test_stream_writer_with_id_fingerprint() -> Result<(), ArrowError> {
- // 1. Schema with Confluent ID in metadata
 let schema_id = 42u32;
- let mut metadata = std::collections::HashMap::new();
- // Assume "confluent.schema.id" is the metadata key the implementation looks for
- metadata.insert("confluent.schema.id".to_string(), schema_id.to_string());
- let schema =
- Schema::new(vec![Field::new("value", DataType::Int64, false)]).with_metadata(metadata);
+ let schema = Schema::new(vec![Field::new("value", DataType::Int64, false)]);

- // 2. Batch with two rows
 let batch = RecordBatch::try_new(
 Arc::new(schema.clone()),
 vec![Arc::new(Int64Array::from(vec![100, 200])) as ArrayRef],
 )?;

- // 3. Write the batch using the stream writer
 let buffer: Vec<u8> = Vec::new();
- let mut writer = AvroStreamWriter::new(buffer, schema)?;
+ let mut writer = WriterBuilder::new(schema)
+ .with_fingerprint_strategy(FingerprintStrategy::ConfluentSchemaId(schema_id))
+ .build::<_, AvroBinaryFormat>(buffer)?;
 writer.write(&batch)?;
 let actual_bytes = writer.into_inner();

- // 4. Construct expected output manually for Confluent wire format
 let mut expected_bytes: Vec<u8> = Vec::new();
 let prefix = {
- let mut p = vec![0x00]; // Confluent magic byte
+ let mut p = vec![CONFLUENT_MAGIC[0]];
 p.extend(&schema_id.to_be_bytes());
 p
 };
@@ -315,33 +318,6 @@ mod tests {
 Ok(())
 }

- #[test]
- fn test_stream_writer_invalid_id_fingerprint_errors() {
- let mut metadata = std::collections::HashMap::new();
- metadata.insert(
- "confluent.schema.id".to_string(),
- "not-a-valid-id".to_string(),
- );
- let schema =
- Schema::new(vec![Field::new("value", DataType::Int64, false)]).with_metadata(metadata);
-
- let buffer: Vec<u8> = Vec::new();
- let result = AvroStreamWriter::new(buffer, schema);
-
- let err = result.expect_err("Writer creation should fail for invalid schema ID");
- assert!(
- matches!(err, ArrowError::InvalidArgumentError(_)),
- "Expected InvalidArgumentError, but got {:?}",
- err
- );
- assert!(
- err.to_string()
- .contains("Invalid Confluent schema ID in metadata"),
- "Error message did not match expectation. Got: {}",
Got: {}", - err - ); - } - #[test] fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> { let batch = make_batch(); From 649825e276e2a9b4ffbd1317d46a2aa819eddb0c Mon Sep 17 00:00:00 2001 From: nathaniel-d-ef Date: Thu, 18 Sep 2025 20:03:18 +0200 Subject: [PATCH 06/10] Refactor and enhance fingerprinting in Avro writer --- arrow-avro/benches/decoder.rs | 4 +- arrow-avro/src/schema.rs | 226 +++++++++++++++++++++++++++---- arrow-avro/src/writer/encoder.rs | 25 ++-- arrow-avro/src/writer/format.rs | 78 ++--------- arrow-avro/src/writer/mod.rs | 31 +++-- 5 files changed, 248 insertions(+), 116 deletions(-) diff --git a/arrow-avro/benches/decoder.rs b/arrow-avro/benches/decoder.rs index 0ca240d12fc9..5ab0f847efcc 100644 --- a/arrow-avro/benches/decoder.rs +++ b/arrow-avro/benches/decoder.rs @@ -418,7 +418,9 @@ macro_rules! dataset { let schema = ApacheSchema::parse_str($schema_json).expect("invalid schema for generator"); let arrow_schema = AvroSchema::new($schema_json.parse().unwrap()); - let fingerprint = arrow_schema.fingerprint().expect("fingerprint failed"); + let fingerprint = arrow_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .expect("fingerprint failed"); let prefix = make_prefix(fingerprint); SIZES .iter() diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 51e98837b745..5b3d1244fbcc 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -316,22 +316,6 @@ pub struct Fixed<'a> { pub attributes: Attributes<'a>, } -/// Defines the strategy for generating the per-record prefix for an Avro binary stream. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum FingerprintStrategy { - /// Use the 64-bit Rabin fingerprint (default for single-object encoding). - #[default] - Rabin, - /// Use a Confluent Schema Registry 32-bit ID. - ConfluentSchemaId(u32), - #[cfg(feature = "md5")] - /// Use the 128-bit MD5 fingerprint. - MD5, - #[cfg(feature = "sha256")] - /// Use the 256-bit SHA-256 fingerprint. - SHA256, -} - /// A wrapper for an Avro schema in its JSON string representation. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AvroSchema { @@ -364,9 +348,9 @@ impl AvroSchema { .map_err(|e| ArrowError::ParseError(format!("Invalid Avro schema JSON: {e}"))) } - /// Returns the Rabin fingerprint of the schema. - pub fn fingerprint(&self) -> Result { - Self::generate_fingerprint_rabin(&self.schema()?) + /// Returns the fingerprint of the schema. + pub fn fingerprint(&self, hash_type: FingerprintAlgorithm) -> Result { + Self::generate_fingerprint(&self.schema()?, hash_type) } /// Generates a fingerprint for the given `Schema` using the specified [`FingerprintAlgorithm`]. @@ -491,6 +475,54 @@ impl AvroSchema { } } +/// Defines the strategy for generating the per-record prefix for an Avro binary stream. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FingerprintStrategy { + /// Use the 64-bit Rabin fingerprint (default for single-object encoding). + #[default] + Rabin, + /// Use a Confluent Schema Registry 32-bit ID. + Id(u32), + #[cfg(feature = "md5")] + /// Use the 128-bit MD5 fingerprint. + MD5, + #[cfg(feature = "sha256")] + /// Use the 256-bit SHA-256 fingerprint. 
+ SHA256,
+}
+
+impl From<Fingerprint> for FingerprintStrategy {
+ fn from(f: Fingerprint) -> Self {
+ Self::from(&f)
+ }
+}
+
+impl From<FingerprintAlgorithm> for FingerprintStrategy {
+ fn from(f: FingerprintAlgorithm) -> Self {
+ match f {
+ FingerprintAlgorithm::Rabin => FingerprintStrategy::Rabin,
+ FingerprintAlgorithm::None => FingerprintStrategy::Id(0),
+ #[cfg(feature = "md5")]
+ FingerprintAlgorithm::MD5 => FingerprintStrategy::MD5,
+ #[cfg(feature = "sha256")]
+ FingerprintAlgorithm::SHA256 => FingerprintStrategy::SHA256,
+ }
+ }
+}
+
+impl From<&Fingerprint> for FingerprintStrategy {
+ fn from(f: &Fingerprint) -> Self {
+ match f {
+ Fingerprint::Rabin(_) => FingerprintStrategy::Rabin,
+ Fingerprint::Id(id) => FingerprintStrategy::Id(*id),
+ #[cfg(feature = "md5")]
+ Fingerprint::MD5(_) => FingerprintStrategy::MD5,
+ #[cfg(feature = "sha256")]
+ Fingerprint::SHA256(_) => FingerprintStrategy::SHA256,
+ }
+ }
+}
+
 /// Supported fingerprint algorithms for Avro schema identification.
 /// For use with Confluent Schema Registry IDs, set to None.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
@@ -522,6 +554,25 @@ impl From<&Fingerprint> for FingerprintAlgorithm {
 }
 }

+impl From<FingerprintStrategy> for FingerprintAlgorithm {
+ fn from(s: FingerprintStrategy) -> Self {
+ Self::from(&s)
+ }
+}
+
+impl From<&FingerprintStrategy> for FingerprintAlgorithm {
+ fn from(s: &FingerprintStrategy) -> Self {
+ match s {
+ FingerprintStrategy::Rabin => FingerprintAlgorithm::Rabin,
+ FingerprintStrategy::Id(_) => FingerprintAlgorithm::None,
+ #[cfg(feature = "md5")]
+ FingerprintStrategy::MD5 => FingerprintAlgorithm::MD5,
+ #[cfg(feature = "sha256")]
+ FingerprintStrategy::SHA256 => FingerprintAlgorithm::SHA256,
+ }
+ }
+}
+
 /// A schema fingerprint in one of the supported formats.
 ///
 /// This is used as the key inside `SchemaStore` `HashMap`. Each `SchemaStore`
@@ -544,6 +595,38 @@ pub enum Fingerprint {
 SHA256([u8; 32]),
 }

+impl From<FingerprintStrategy> for Fingerprint {
+ fn from(s: FingerprintStrategy) -> Self {
+ Self::from(&s)
+ }
+}
+
+impl From<&FingerprintStrategy> for Fingerprint {
+ fn from(s: &FingerprintStrategy) -> Self {
+ match s {
+ FingerprintStrategy::Rabin => Fingerprint::Rabin(0),
+ FingerprintStrategy::Id(id) => Fingerprint::Id(*id),
+ #[cfg(feature = "md5")]
+ FingerprintStrategy::MD5 => Fingerprint::MD5([0; 16]),
+ #[cfg(feature = "sha256")]
+ FingerprintStrategy::SHA256 => Fingerprint::SHA256([0; 32]),
+ }
+ }
+}
+
+impl From<FingerprintAlgorithm> for Fingerprint {
+ fn from(s: FingerprintAlgorithm) -> Self {
+ match s {
+ FingerprintAlgorithm::Rabin => Fingerprint::Rabin(0),
+ FingerprintAlgorithm::None => Fingerprint::Id(0),
+ #[cfg(feature = "md5")]
+ FingerprintAlgorithm::MD5 => Fingerprint::MD5([0; 16]),
+ #[cfg(feature = "sha256")]
+ FingerprintAlgorithm::SHA256 => Fingerprint::SHA256([0; 32]),
+ }
+ }
+}
+
 impl Fingerprint {
 /// Loads the 32-bit Schema Registry fingerprint (Confluent Schema Registry ID).
 ///
@@ -555,6 +638,75 @@ impl Fingerprint {
 pub fn load_fingerprint_id(id: u32) -> Self {
 Fingerprint::Id(u32::from_be(id))
 }
+
+ /// Constructs a serialized prefix represented as a `Vec<u8>` based on the variant of the enum.
+ ///
+ /// This method serializes data in different formats depending on the variant of `self`:
+ /// - **`Id(id)`**: Uses the Confluent wire format, which includes a predefined magic header (`CONFLUENT_MAGIC`)
+ /// followed by the big-endian byte representation of the `id`.
+ /// - **`Rabin(val)`**: Uses the Avro single-object specification format. This includes a different magic header
+ /// (`SINGLE_OBJECT_MAGIC`) followed by the little-endian byte representation of the `val`.
+ /// - **`MD5(bytes)`** (optional, `md5` feature enabled): A non-standard extension that adds the
+ /// `SINGLE_OBJECT_MAGIC` header followed by the provided `bytes`.
+ /// - **`SHA256(bytes)`** (optional, `sha256` feature enabled): Similar to the `MD5` variant, this is
+ /// a non-standard extension that attaches the `SINGLE_OBJECT_MAGIC` header followed by the given `bytes`.
+ ///
+ /// # Returns
+ ///
+ /// A `Vec<u8>` containing the serialized prefix data.
+ ///
+ /// # Features
+ ///
+ /// - You can optionally enable the `md5` feature to include the `MD5` variant.
+ /// - You can optionally enable the `sha256` feature to include the `SHA256` variant.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use your_crate::YourEnum;
+ ///
+ /// let prefix = YourEnum::Id(12345).make_prefix();
+ /// assert_eq!(prefix, /* expected Vec<u8> data */);
+ /// ```
+ ///
+ /// ```rust
+ /// use your_crate::YourEnum;
+ ///
+ /// let prefix = YourEnum::Rabin(67890).make_prefix();
+ /// assert_eq!(prefix, /* expected Vec<u8> data */);
+ /// ```
+ pub fn make_prefix(&self) -> Vec<u8> {
+ match self {
+ Self::Id(id) => {
+ let mut out = Vec::with_capacity(CONFLUENT_MAGIC.len() + 4);
+ out.extend_from_slice(&CONFLUENT_MAGIC);
+ out.extend_from_slice(&id.to_be_bytes());
+ out
+ }
+ Self::Rabin(val) => {
+ let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + 8);
+ out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
+ out.extend_from_slice(&val.to_le_bytes());
+ out
+ }
+ #[cfg(feature = "md5")]
+ Self::MD5(bytes) => {
+ // Non-standard extension
+ let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + bytes.len());
+ out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
+ out.extend_from_slice(bytes);
+ out
+ }
+ #[cfg(feature = "sha256")]
+ Self::SHA256(bytes) => {
+ // Non-standard extension
+ let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + bytes.len());
+ out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
+ out.extend_from_slice(bytes);
+ out
+ }
+ }
+ }
 }

 /// An in-memory cache of Avro schemas, indexed by their fingerprint.
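For review context, a hedged sketch of the bytes `make_prefix` produces at this point in the series, derived from the `CONFLUENT_MAGIC` (`[0x00]`) and `SINGLE_OBJECT_MAGIC` (`[0xC3, 0x01]`) constants in this file:

```rust
// Confluent wire format: 1 magic byte + 4-byte big-endian ID = 5 bytes.
assert_eq!(
    Fingerprint::Id(1).make_prefix(),
    vec![0x00, 0x00, 0x00, 0x00, 0x01]
);
// Single-object encoding: 2 magic bytes + 8-byte little-endian Rabin = 10 bytes.
assert_eq!(
    Fingerprint::Rabin(1).make_prefix(),
    vec![0xC3, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
```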
@@ -1650,17 +1802,25 @@ mod tests {
 let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap());
 let mut schemas: HashMap<Fingerprint, AvroSchema> = HashMap::new();
 schemas.insert(
- int_avro_schema.fingerprint().unwrap(),
+ int_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap(),
 int_avro_schema.clone(),
 );
 schemas.insert(
- record_avro_schema.fingerprint().unwrap(),
+ record_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap(),
 record_avro_schema.clone(),
 );
 let store = SchemaStore::try_from(schemas).unwrap();
- let int_fp = int_avro_schema.fingerprint().unwrap();
+ let int_fp = int_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap();
 assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema));
- let rec_fp = record_avro_schema.fingerprint().unwrap();
+ let rec_fp = record_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap();
 assert_eq!(store.lookup(&rec_fp).cloned(), Some(record_avro_schema));
 }

@@ -1670,21 +1830,29 @@ mod tests {
 let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap());
 let mut schemas: HashMap<Fingerprint, AvroSchema> = HashMap::new();
 schemas.insert(
- int_avro_schema.fingerprint().unwrap(),
+ int_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap(),
 int_avro_schema.clone(),
 );
 schemas.insert(
- record_avro_schema.fingerprint().unwrap(),
+ record_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap(),
 record_avro_schema.clone(),
 );
 // Insert duplicate of int schema
 schemas.insert(
- int_avro_schema.fingerprint().unwrap(),
+ int_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap(),
 int_avro_schema.clone(),
 );
 let store = SchemaStore::try_from(schemas).unwrap();
 assert_eq!(store.schemas.len(), 2);
- let int_fp = int_avro_schema.fingerprint().unwrap();
+ let int_fp = int_avro_schema
+ .fingerprint(FingerprintAlgorithm::Rabin)
+ .unwrap();
 assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema));
 }

@@ -1744,7 +1912,7 @@ mod tests {
 fn test_set_and_lookup_with_provided_fingerprint() {
 let mut store = SchemaStore::new();
 let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap());
- let fp = schema.fingerprint().unwrap();
+ let fp = schema.fingerprint(FingerprintAlgorithm::Rabin).unwrap();
 let out_fp = store.set(fp, schema.clone()).unwrap();
 assert_eq!(out_fp, fp);
 assert_eq!(store.lookup(&fp).cloned(), Some(schema));
@@ -1754,7 +1922,7 @@ mod tests {
 fn test_set_duplicate_same_schema_ok() {
 let mut store = SchemaStore::new();
 let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap());
- let fp = schema.fingerprint().unwrap();
+ let fp = schema.fingerprint(FingerprintAlgorithm::Rabin).unwrap();
 let _ = store.set(fp, schema.clone()).unwrap();
 let _ = store.set(fp, schema.clone()).unwrap();
 assert_eq!(store.schemas.len(), 1);
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index 6bf4ab1650e4..0f7d1af5b963 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -523,6 +523,7 @@ struct FieldBinding {
 pub struct RecordEncoderBuilder<'a> {
 avro_root: &'a AvroField,
 arrow_schema: &'a ArrowSchema,
+ fingerprint: Option<Fingerprint>,
 }

 impl<'a> RecordEncoderBuilder<'a> {
@@ -531,9 +532,15 @@ impl<'a> RecordEncoderBuilder<'a> {
 Self {
 avro_root,
 arrow_schema,
+ fingerprint: None,
 }
 }

+ pub(crate) fn with_fingerprint(mut self, fingerprint: Option<Fingerprint>) -> Self {
+ self.fingerprint = fingerprint;
+ self
+ }
+
 /// Build the `RecordEncoder` by walking the Avro **record** root in Avro order,
 /// resolving each field to an Arrow index by name.
 pub fn build(self) -> Result<RecordEncoder, ArrowError> {
@@ -558,7 +565,10 @@ impl<'a> RecordEncoderBuilder<'a> {
 )?,
 });
 }
- Ok(RecordEncoder { columns })
+ Ok(RecordEncoder {
+ columns,
+ prefix: self.fingerprint.as_ref().map(|fp| fp.make_prefix()),
+ })
 }
 }

@@ -570,6 +580,8 @@
 #[derive(Debug, Clone)]
 pub struct RecordEncoder {
 columns: Vec<FieldBinding>,
+ /// Optional pre-built, variable-length prefix written before each record.
+ prefix: Option<Vec<u8>>,
 }

 impl RecordEncoder {
@@ -601,18 +613,13 @@ impl RecordEncoder {
 /// Encode a `RecordBatch` using this encoder plan.
 ///
 /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes.
- pub fn encode<W: Write>(
- &self,
- out: &mut W,
- batch: &RecordBatch,
- prefix: Option<&[u8]>,
- ) -> Result<(), ArrowError> {
+ pub fn encode<W: Write>(&self, out: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> {
 let mut column_encoders = self.prepare_for_batch(batch)?;
 for row in 0..batch.num_rows() {
- if let Some(prefix) = prefix {
+ if let Some(prefix) = &self.prefix {
 if !prefix.is_empty() {
 out.write_all(prefix).map_err(|e| {
- ArrowError::IoError(format!("write single-object prefix: {e}"), e)
+ ArrowError::IoError(format!("Failed to write single-object prefix: {e}"), e)
 })?;
 }
 }
diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs
index 5feed13345d0..56c67cac6af8 100644
--- a/arrow-avro/src/writer/format.rs
+++ b/arrow-avro/src/writer/format.rs
@@ -17,8 +17,8 @@

 use crate::compression::{CompressionCodec, CODEC_METADATA_KEY};
 use crate::schema::{
- AvroSchema, Fingerprint, FingerprintStrategy, CONFLUENT_MAGIC, SCHEMA_METADATA_KEY,
- SINGLE_OBJECT_MAGIC,
+ AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, CONFLUENT_MAGIC,
+ SCHEMA_METADATA_KEY, SINGLE_OBJECT_MAGIC,
 };
 use crate::writer::encoder::write_long;
 use arrow_schema::{ArrowError, Schema};
@@ -28,6 +28,11 @@ use std::io::Write;

 /// Format abstraction implemented by each container‐level writer.
 pub trait AvroFormat: Debug + Default {
+ /// If `true`, the writer for this format will prepend the configured
+ /// per-record prefix before each encoded record. If `false`, the writer can
+ /// skip this step. This is a performance hint for the writer.
+ const NEEDS_PREFIX: bool;
+
 /// Write any bytes required at the very beginning of the output stream
 /// (file header, etc.).
 /// Implementations **must not** write any record data.
@@ -36,22 +41,10 @@ pub trait AvroFormat: Debug + Default {
 writer: &mut W,
 schema: &Schema,
 compression: Option<CompressionCodec>,
- fingerprint_strategy: FingerprintStrategy,
 ) -> Result<(), ArrowError>;

 /// Return the 16‑byte sync marker (OCF) or `None` (binary stream).
 fn sync_marker(&self) -> Option<&[u8; 16]>;
-
- /// Return the 10‑byte **Avro single‑object** prefix (`C3 01` magic +
- /// little‑endian schema fingerprint) to be written **before each record**,
- /// or `None` if the format does not use single‑object encoding.
- ///
- /// The default implementation returns `None`. `AvroBinaryFormat` overrides
- /// this to return the appropriate single-object encoding prefix.
- #[inline]
- fn single_object_prefix(&self) -> Option<&[u8]> {
- None
- }
 }

 /// Avro Object Container File (OCF) format writer.
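To illustrate the new associated const, here is a hypothetical format that opts out of prefixing (a sketch only, not part of this patch; `RawBinaryFormat` is an invented name, and the impl assumes the post-refactor trait shape and the imports already in scope in this file):

```rust
/// Hypothetical format that writes raw record bodies with no framing.
#[derive(Debug, Default)]
struct RawBinaryFormat;

impl AvroFormat for RawBinaryFormat {
    // The builder checks `F::NEEDS_PREFIX` at compile time, so formats that
    // opt out skip fingerprint computation entirely.
    const NEEDS_PREFIX: bool = false;

    fn start_stream<W: Write>(
        &mut self,
        _writer: &mut W,
        _schema: &Schema,
        _compression: Option<CompressionCodec>,
    ) -> Result<(), ArrowError> {
        Ok(()) // no header bytes for this format
    }

    fn sync_marker(&self) -> Option<&[u8; 16]> {
        None
    }
}
```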
@@ -61,12 +54,12 @@ pub struct AvroOcfFormat {
 }

 impl AvroFormat for AvroOcfFormat {
+ const NEEDS_PREFIX: bool = false;
 fn start_stream<W: Write>(
 &mut self,
 writer: &mut W,
 schema: &Schema,
 compression: Option<CompressionCodec>,
- _fingerprint_strategy: FingerprintStrategy,
 ) -> Result<(), ArrowError> {
 let mut rng = rand::rng();
 rng.fill_bytes(&mut self.sync_marker);
@@ -114,18 +107,15 @@ impl AvroFormat for AvroOcfFormat {
 /// See:
 /// See:
 #[derive(Debug, Default)]
-pub struct AvroBinaryFormat {
- /// Pre-built, variable-length prefix written before each record.
- prefix: Vec<u8>,
-}
+pub struct AvroBinaryFormat {}

 impl AvroFormat for AvroBinaryFormat {
+ const NEEDS_PREFIX: bool = true;
 fn start_stream<W: Write>(
 &mut self,
 _writer: &mut W,
 schema: &Schema,
 compression: Option<CompressionCodec>,
- fingerprint_strategy: FingerprintStrategy,
 ) -> Result<(), ArrowError> {
 if compression.is_some() {
 return Err(ArrowError::InvalidArgumentError(
 "Compression not supported for Avro binary streaming".to_string(),
 ));
 }

- self.prefix.clear();
-
- match fingerprint_strategy {
- FingerprintStrategy::ConfluentSchemaId(id) => {
- self.prefix.push(CONFLUENT_MAGIC[0]);
- self.prefix.extend_from_slice(&id.to_be_bytes());
- }
- strategy => {
- // All other strategies use the single-object encoding format
- self.prefix.extend_from_slice(&SINGLE_OBJECT_MAGIC);
-
- let avro_schema = AvroSchema::try_from(schema)?;
- let fp = match strategy {
- FingerprintStrategy::Rabin => avro_schema.fingerprint()?,
- #[cfg(feature = "md5")]
- FingerprintStrategy::MD5 => AvroSchema::generate_fingerprint(
- &avro_schema.schema()?,
- crate::schema::FingerprintAlgorithm::MD5,
- )?,
- #[cfg(feature = "sha256")]
- FingerprintStrategy::SHA256 => AvroSchema::generate_fingerprint(
- &avro_schema.schema()?,
- crate::schema::FingerprintAlgorithm::SHA256,
- )?,
- FingerprintStrategy::ConfluentSchemaId(_) => unreachable!(),
- };
-
- match fp {
- Fingerprint::Rabin(val) => self.prefix.extend_from_slice(&val.to_le_bytes()),
- #[cfg(feature = "md5")]
- Fingerprint::MD5(val) => self.prefix.extend_from_slice(val.as_ref()),
- #[cfg(feature = "sha256")]
- Fingerprint::SHA256(val) => self.prefix.extend_from_slice(val.as_ref()),
- Fingerprint::Id(_) => return Err(ArrowError::InvalidArgumentError(
- "ConfluentSchemaId strategy cannot be used with a hash-based fingerprint."
- .to_string(),
- )),
- }
- }
- }
 Ok(())
 }

 fn sync_marker(&self) -> Option<&[u8; 16]> {
 None
 }
-
- fn single_object_prefix(&self) -> Option<&[u8]> {
- if self.prefix.is_empty() {
- None
- } else {
- Some(&self.prefix)
- }
- }
 }

 #[inline]
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index ad9bf3c9e6fe..be62ba53aaa9 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -34,7 +34,9 @@ pub mod format;

 use crate::codec::AvroFieldBuilder;
 use crate::compression::CompressionCodec;
-use crate::schema::{AvroSchema, FingerprintStrategy, SCHEMA_METADATA_KEY};
+use crate::schema::{
+ AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, SCHEMA_METADATA_KEY,
+};
 use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder};
 use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
 use arrow_array::RecordBatch;
@@ -93,15 +95,28 @@ impl WriterBuilder {
 Some(json) => AvroSchema::new(json.clone()),
 None => AvroSchema::try_from(&self.schema)?,
 };
+
+ let maybe_fingerprint = if F::NEEDS_PREFIX {
+ match self.fingerprint_strategy {
+ FingerprintStrategy::Id(id) => Some(Fingerprint::Id(id)),
+ strategy => Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?),
+ }
+ } else {
+ None
+ };
+ let maybe_prefix = maybe_fingerprint.as_ref().map(|fp| fp.make_prefix());
+
 let mut md = self.schema.metadata().clone();
 md.insert(
 SCHEMA_METADATA_KEY.to_string(),
 avro_schema.clone().json_string,
 );
 let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md));
- format.start_stream(&mut writer, &schema, self.codec, self.fingerprint_strategy)?;
+ format.start_stream(&mut writer, &schema, self.codec)?;
 let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?;
- let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref()).build()?;
+ let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref())
+ .with_fingerprint(maybe_fingerprint)
+ .build()?;
 Ok(Writer {
 writer,
 schema,
@@ -186,8 +201,7 @@ impl<W: Write, F: AvroFormat> Writer<W, F> {
 fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> {
 let mut buf = Vec::<u8>::with_capacity(1024);
- self.encoder
- .encode(&mut buf, batch, self.format.single_object_prefix())?;
+ self.encoder.encode(&mut buf, batch)?;
 let encoded = match self.compression {
 Some(codec) => codec.compress(&buf)?,
 None => buf,
@@ -204,8 +218,7 @@
 }

 fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
- self.encoder
- .encode(&mut self.writer, batch, self.format.single_object_prefix())?;
+ self.encoder.encode(&mut self.writer, batch)?;
 Ok(())
 }
 }
@@ -247,7 +260,7 @@ mod tests {
 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 let avro_schema = AvroSchema::try_from(&schema)?;
- let fingerprint = avro_schema.fingerprint()?;
+ let fingerprint = avro_schema.fingerprint(FingerprintAlgorithm::Rabin)?;
 let mut expected_prefix = Vec::from(crate::schema::SINGLE_OBJECT_MAGIC);
 match fingerprint {
 crate::schema::Fingerprint::Rabin(val) => expected_prefix.extend(val.to_le_bytes()),
@@ -291,7 +304,7 @@ mod tests {
 let buffer: Vec<u8> = Vec::new();
 let mut writer = WriterBuilder::new(schema)
- .with_fingerprint_strategy(FingerprintStrategy::ConfluentSchemaId(schema_id))
+ .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id))
 .build::<_, AvroBinaryFormat>(buffer)?;
 writer.write(&batch)?;
 let actual_bytes = writer.into_inner();

From c5ff88ab323d9fa2daea585d40265e225701ed51 Mon Sep 17 00:00:00 2001
From: nathaniel-d-ef
Date: Thu, 18 Sep 2025 20:15:49 +0200
Subject: [PATCH 07/10] Remove verbose examples from `make_prefix` method documentation.

---
 arrow-avro/src/schema.rs | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs
index 5b3d1244fbcc..f7af4c3cbe72 100644
--- a/arrow-avro/src/schema.rs
+++ b/arrow-avro/src/schema.rs
@@ -659,22 +659,7 @@ impl Fingerprint {
 ///
 /// - You can optionally enable the `md5` feature to include the `MD5` variant.
 /// - You can optionally enable the `sha256` feature to include the `SHA256` variant.
- ///
- /// # Examples
- ///
- /// ```rust
- /// use your_crate::YourEnum;
- ///
- /// let prefix = YourEnum::Id(12345).make_prefix();
- /// assert_eq!(prefix, /* expected Vec<u8> data */);
- /// ```
- ///
- /// ```rust
- /// use your_crate::YourEnum;
- ///
- /// let prefix = YourEnum::Rabin(67890).make_prefix();
- /// assert_eq!(prefix, /* expected Vec<u8> data */);
- /// ```
+ ///
 pub fn make_prefix(&self) -> Vec<u8> {
 match self {
 Self::Id(id) => {

From 286da05d8fff259b02781a9c30d551758d846e91 Mon Sep 17 00:00:00 2001
From: nathaniel-d-ef
Date: Thu, 18 Sep 2025 20:16:26 +0200
Subject: [PATCH 08/10] Trim whitespace in `make_prefix` method documentation.

---
 arrow-avro/src/schema.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs
index f7af4c3cbe72..f7b39db8e925 100644
--- a/arrow-avro/src/schema.rs
+++ b/arrow-avro/src/schema.rs
@@ -659,7 +659,7 @@ impl Fingerprint {
 ///
 /// - You can optionally enable the `md5` feature to include the `MD5` variant.
 /// - You can optionally enable the `sha256` feature to include the `SHA256` variant.
- ///
+ ///
 pub fn make_prefix(&self) -> Vec<u8> {
 match self {
 Self::Id(id) => {

From 693360beaa6eee761d20a5469636bb8b9a4abd31 Mon Sep 17 00:00:00 2001
From: nathaniel-d-ef
Date: Thu, 18 Sep 2025 21:15:37 +0200
Subject: [PATCH 09/10] Refactor Avro writer to use stack-allocated `Prefix` for encoding.

---
 arrow-avro/src/schema.rs         | 65 +++++++++++++++++++++-----------
 arrow-avro/src/writer/encoder.rs | 30 +++++++++------
 2 files changed, 62 insertions(+), 33 deletions(-)

diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs
index f7b39db8e925..58c4cc8e8e9d 100644
--- a/arrow-avro/src/schema.rs
+++ b/arrow-avro/src/schema.rs
@@ -33,6 +33,10 @@ pub const SINGLE_OBJECT_MAGIC: [u8; 2] = [0xC3, 0x01];
 /// The Confluent "magic" byte (`0x00`)
 pub const CONFLUENT_MAGIC: [u8; 1] = [0x00];

+/// The maximum possible length of a prefix.
+/// SHA256 (32) + single-object magic (2)
+pub const MAX_PREFIX_LEN: usize = 34;
+
 /// The metadata key used for storing the JSON encoded [`Schema`]
 pub const SCHEMA_METADATA_KEY: &str = "avro.schema";
@@ -475,6 +479,20 @@ impl AvroSchema {
 }
 }

+/// A stack-allocated, fixed-size buffer for the prefix.
+#[derive(Debug, Copy, Clone)]
+pub struct Prefix {
+ buf: [u8; MAX_PREFIX_LEN],
+ len: u8,
+}
+
+impl Prefix {
+ #[inline]
+ pub(crate) fn as_slice(&self) -> &[u8] {
+ &self.buf[..self.len as usize]
+ }
+}
+
 /// Defines the strategy for generating the per-record prefix for an Avro binary stream.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
 pub enum FingerprintStrategy {
@@ -653,43 +671,48 @@ impl Fingerprint {
 ///
 /// # Returns
 ///
- /// A `Vec<u8>` containing the serialized prefix data.
+ /// A `Prefix` containing the serialized prefix data.
 ///
 /// # Features
 ///
 /// - You can optionally enable the `md5` feature to include the `MD5` variant.
 /// - You can optionally enable the `sha256` feature to include the `SHA256` variant.
 ///
- pub fn make_prefix(&self) -> Vec<u8> {
- match self {
+ pub fn make_prefix(&self) -> Prefix {
+ let mut buf = [0u8; MAX_PREFIX_LEN];
+ let len = match self {
 Self::Id(id) => {
- let mut out = Vec::with_capacity(CONFLUENT_MAGIC.len() + 4);
- out.extend_from_slice(&CONFLUENT_MAGIC);
- out.extend_from_slice(&id.to_be_bytes());
- out
+ let prefix_slice = &mut buf[..5];
+ prefix_slice[..1].copy_from_slice(&CONFLUENT_MAGIC);
+ prefix_slice[1..5].copy_from_slice(&id.to_be_bytes());
+ 5
 }
 Self::Rabin(val) => {
- let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + 8);
- out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
- out.extend_from_slice(&val.to_le_bytes());
- out
+ let prefix_slice = &mut buf[..10];
+ prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
+ prefix_slice[2..10].copy_from_slice(&val.to_le_bytes());
+ 10
 }
 #[cfg(feature = "md5")]
 Self::MD5(bytes) => {
- // Non-standard extension
- let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + bytes.len());
- out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
- out.extend_from_slice(bytes);
- out
+ const LEN: usize = 2 + 16;
+ let prefix_slice = &mut buf[..LEN];
+ prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
+ prefix_slice[2..LEN].copy_from_slice(bytes);
+ LEN
 }
 #[cfg(feature = "sha256")]
 Self::SHA256(bytes) => {
- // Non-standard extension
- let mut out = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + bytes.len());
- out.extend_from_slice(&SINGLE_OBJECT_MAGIC);
- out.extend_from_slice(bytes);
- out
+ const LEN: usize = 2 + 32;
+ let prefix_slice = &mut buf[..LEN];
+ prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
+ prefix_slice[2..LEN].copy_from_slice(bytes);
+ LEN
 }
+ };
+ Prefix {
+ buf,
+ len: len as u8,
+ }
 }
 }
diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs
index 0f7d1af5b963..518179530f3d 100644
--- a/arrow-avro/src/writer/encoder.rs
+++ b/arrow-avro/src/writer/encoder.rs
@@ -18,7 +18,7 @@
 //! Avro Encoder for Arrow types.

 use crate::codec::{AvroDataType, AvroField, Codec};
-use crate::schema::{Fingerprint, Nullability};
+use crate::schema::{Fingerprint, Nullability, Prefix};
 use arrow_array::cast::AsArray;
 use arrow_array::types::{
 ArrowPrimitiveType, Float32Type, Float64Type, Int32Type, Int64Type, IntervalDayTimeType,
@@ -567,7 +567,7 @@ impl<'a> RecordEncoderBuilder<'a> {
 }
 Ok(RecordEncoder {
 columns,
- prefix: self.fingerprint.as_ref().map(|fp| fp.make_prefix()),
+ prefix: self.fingerprint.map(|fp| fp.make_prefix()),
 })
 }
 }
@@ -581,7 +581,7 @@ pub struct RecordEncoder {
 columns: Vec<FieldBinding>,
 /// Optional pre-built, variable-length prefix written before each record.
- prefix: Option<Vec<u8>>,
+ prefix: Option<Prefix>,
 }

 impl RecordEncoder {
@@ -615,17 +615,23 @@ impl RecordEncoder {
 /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes.
 pub fn encode<W: Write>(&self, out: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> {
 let mut column_encoders = self.prepare_for_batch(batch)?;
- for row in 0..batch.num_rows() {
- if let Some(prefix) = &self.prefix {
- if !prefix.is_empty() {
- out.write_all(prefix).map_err(|e| {
- ArrowError::IoError(format!("Failed to write single-object prefix: {e}"), e)
- })?;
+ let n = batch.num_rows();
+ match self.prefix {
+ Some(prefix) => {
+ for row in 0..n {
+ out.write_all(prefix.as_slice())
+ .map_err(|e| ArrowError::IoError(format!("write prefix: {e}"), e))?;
+ for enc in column_encoders.iter_mut() {
+ enc.encode(out, row)?;
+ }
 }
 }
-
- for encoder in column_encoders.iter_mut() {
- encoder.encode(out, row)?;
+ None => {
+ for row in 0..n {
+ for enc in column_encoders.iter_mut() {
+ enc.encode(out, row)?;
+ }
+ }
 }
 }
 Ok(())

From f551b7749677cc188880866eaea0ea4e98d28d99 Mon Sep 17 00:00:00 2001
From: nathaniel-d-ef
Date: Mon, 22 Sep 2025 14:53:49 +0200
Subject: [PATCH 10/10] PR comments: Refactor `WriterBuilder` to use optional
 `FingerprintStrategy` and update prefix handling logic.

- Transitioned `FingerprintStrategy` from mandatory to optional in `WriterBuilder`, defaulting to `Rabin` when unspecified.
- Simplified prefix creation with a new `write_prefix` utility function.
- Enhanced unit tests to verify backward compatibility and schema decoding for stream writers.
---
 arrow-avro/src/schema.rs        |  48 +++++--------
 arrow-avro/src/writer/format.rs |   2 +-
 arrow-avro/src/writer/mod.rs    | 122 +++++++++++++++-----------------
 3 files changed, 78 insertions(+), 94 deletions(-)

diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs
index f9ae4bd1f97b..88c47aadea7a 100644
--- a/arrow-avro/src/schema.rs
+++ b/arrow-avro/src/schema.rs
@@ -682,42 +682,30 @@ impl Fingerprint {
 pub fn make_prefix(&self) -> Prefix {
 let mut buf = [0u8; MAX_PREFIX_LEN];
 let len = match self {
- Self::Id(id) => {
- let prefix_slice = &mut buf[..5];
- prefix_slice[..1].copy_from_slice(&CONFLUENT_MAGIC);
- prefix_slice[1..5].copy_from_slice(&id.to_be_bytes());
- 5
- }
- Self::Rabin(val) => {
- let prefix_slice = &mut buf[..10];
- prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
- prefix_slice[2..10].copy_from_slice(&val.to_le_bytes());
- 10
- }
+ Self::Id(val) => write_prefix(&mut buf, &CONFLUENT_MAGIC, &val.to_be_bytes()),
+ Self::Rabin(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, &val.to_le_bytes()),
 #[cfg(feature = "md5")]
- Self::MD5(bytes) => {
- const LEN: usize = 2 + 16;
- let prefix_slice = &mut buf[..LEN];
- prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
- prefix_slice[2..LEN].copy_from_slice(bytes);
- LEN
- }
+ Self::MD5(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, val),
 #[cfg(feature = "sha256")]
- Self::SHA256(bytes) => {
- const LEN: usize = 2 + 32;
- let prefix_slice = &mut buf[..LEN];
- prefix_slice[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC);
- prefix_slice[2..LEN].copy_from_slice(bytes);
- LEN
- }
+ Self::SHA256(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, val),
 };
- Prefix {
- buf,
- len: len as u8,
- }
+ Prefix { buf, len }
 }
 }

+fn write_prefix<const MAGIC_LEN: usize, const PAYLOAD_LEN: usize>(
+ buf: &mut [u8; MAX_PREFIX_LEN],
+ magic: &[u8; MAGIC_LEN],
+ payload: &[u8; PAYLOAD_LEN],
+) -> u8 {
+ debug_assert!(MAGIC_LEN + PAYLOAD_LEN <= MAX_PREFIX_LEN);
+ let total = MAGIC_LEN + PAYLOAD_LEN;
+ let prefix_slice = &mut buf[..total];
+ prefix_slice[..MAGIC_LEN].copy_from_slice(magic);
+ prefix_slice[MAGIC_LEN..total].copy_from_slice(payload);
+ total as u8
+}
+
 /// An in-memory cache of Avro schemas, indexed by their fingerprint.
 ///
 /// `SchemaStore` provides a mechanism to store and retrieve Avro schemas efficiently.
diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs
index 56c67cac6af8..a6ddba38d24b 100644
--- a/arrow-avro/src/writer/format.rs
+++ b/arrow-avro/src/writer/format.rs
@@ -114,7 +114,7 @@ impl AvroFormat for AvroBinaryFormat {
 fn start_stream<W: Write>(
 &mut self,
 _writer: &mut W,
- schema: &Schema,
+ _schema: &Schema,
 compression: Option<CompressionCodec>,
 ) -> Result<(), ArrowError> {
 if compression.is_some() {
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
index be62ba53aaa9..7a7b0d283750 100644
--- a/arrow-avro/src/writer/mod.rs
+++ b/arrow-avro/src/writer/mod.rs
@@ -50,7 +50,7 @@ pub struct WriterBuilder {
 schema: Schema,
 codec: Option<CompressionCodec>,
 capacity: usize,
- fingerprint_strategy: FingerprintStrategy,
+ fingerprint_strategy: Option<FingerprintStrategy>,
 }

 impl WriterBuilder {
@@ -60,14 +60,14 @@ impl WriterBuilder {
 schema,
 codec: None,
 capacity: 1024,
- fingerprint_strategy: FingerprintStrategy::default(),
+ fingerprint_strategy: None,
 }
 }

 /// Set the fingerprinting strategy for the stream writer.
 /// This determines the per-record prefix format.
 pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self {
- self.fingerprint_strategy = strategy;
+ self.fingerprint_strategy = Some(strategy);
 self
 }

@@ -98,13 +98,18 @@ impl WriterBuilder {
 let maybe_fingerprint = if F::NEEDS_PREFIX {
 match self.fingerprint_strategy {
- FingerprintStrategy::Id(id) => Some(Fingerprint::Id(id)),
- strategy => Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?),
+ Some(FingerprintStrategy::Id(id)) => Some(Fingerprint::Id(id)),
+ Some(strategy) => {
+ Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?)
+ }
+ None => Some(
+ avro_schema
+ .fingerprint(FingerprintAlgorithm::from(FingerprintStrategy::Rabin))?,
+ ),
 }
 } else {
 None
 };
- let maybe_prefix = maybe_fingerprint.as_ref().map(|fp| fp.make_prefix());

 let mut md = self.schema.metadata().clone();
 md.insert(
@@ -256,78 +261,69 @@ mod tests {
 }

 #[test]
- fn test_stream_writer_writes_prefix_per_row() -> Result<(), ArrowError> {
+ fn test_stream_writer_writes_prefix_per_row_rt() -> Result<(), ArrowError> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let avro_schema = AvroSchema::try_from(&schema)?;
-
- let fingerprint = avro_schema.fingerprint(FingerprintAlgorithm::Rabin)?;
- let mut expected_prefix = Vec::from(crate::schema::SINGLE_OBJECT_MAGIC);
- match fingerprint {
- crate::schema::Fingerprint::Rabin(val) => expected_prefix.extend(val.to_le_bytes()),
- _ => panic!("Expected Rabin fingerprint for default stream writer"),
- }
 let batch = RecordBatch::try_new(
 Arc::new(schema.clone()),
 vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef],
 )?;
-
- let buffer: Vec<u8> = Vec::new();
- let mut writer = AvroStreamWriter::new(buffer, schema)?;
+ let buf: Vec<u8> = Vec::new();
+ let mut writer = AvroStreamWriter::new(buf, schema.clone())?;
 writer.write(&batch)?;
- let actual_bytes = writer.into_inner();
-
- let mut expected_bytes = Vec::new();
- // Row 1: prefix + zig-zag encoded(10)
- expected_bytes.extend(&expected_prefix);
- expected_bytes.push(0x14);
- // Row 2: prefix + zig-zag encoded(20)
- expected_bytes.extend(&expected_prefix);
- expected_bytes.push(0x28);
-
- assert_eq!(
- actual_bytes, expected_bytes,
- "Stream writer output did not match expected prefix-per-row format"
- );
+ let encoded = writer.into_inner();
encoded = writer.into_inner(); + let mut store = SchemaStore::new(); // Rabin by default + let avro_schema = AvroSchema::try_from(&schema)?; + let _fp = store.register(avro_schema)?; + let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .build_decoder()?; + let _consumed = decoder.decode(&encoded)?; + let decoded = decoder + .flush()? + .expect("expected at least one batch from decoder"); + assert_eq!(decoded.num_columns(), 1); + assert_eq!(decoded.num_rows(), 2); + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("int column"); + assert_eq!(col, &Int32Array::from(vec![10, 20])); Ok(()) } #[test] - fn test_stream_writer_with_id_fingerprint() -> Result<(), ArrowError> { - let schema_id = 42u32; - let schema = Schema::new(vec![Field::new("value", DataType::Int64, false)]); - + fn test_stream_writer_with_id_fingerprint_rt() -> Result<(), ArrowError> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let batch = RecordBatch::try_new( Arc::new(schema.clone()), - vec![Arc::new(Int64Array::from(vec![100, 200])) as ArrayRef], + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], )?; - - let buffer: Vec = Vec::new(); - let mut writer = WriterBuilder::new(schema) + let schema_id: u32 = 42; + let mut writer = WriterBuilder::new(schema.clone()) .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id)) - .build::<_, AvroBinaryFormat>(buffer)?; + .build::<_, AvroBinaryFormat>(Vec::new())?; writer.write(&batch)?; - let actual_bytes = writer.into_inner(); - - let mut expected_bytes: Vec = Vec::new(); - let prefix = { - let mut p = vec![CONFLUENT_MAGIC[0]]; - p.extend(&schema_id.to_be_bytes()); - p - }; - - // Row 1: prefix + zig-zag encoded(100) -> 200 -> [0xC8, 0x01] - expected_bytes.extend(&prefix); - expected_bytes.extend(&[0xC8, 0x01]); - // Row 2: prefix + zig-zag encoded(200) -> 400 -> [0x90, 0x03] - expected_bytes.extend(&prefix); - expected_bytes.extend(&[0x90, 0x03]); - - // 5. Assert - assert_eq!( - actual_bytes, expected_bytes, - "Stream writer output for Confluent ID did not match expected format" - ); + let encoded = writer.into_inner(); + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let avro_schema = AvroSchema::try_from(&schema)?; + let _ = store.set(Fingerprint::Id(schema_id), avro_schema)?; + let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .build_decoder()?; + let _ = decoder.decode(&encoded)?; + let decoded = decoder + .flush()? + .expect("expected at least one batch from decoder"); + assert_eq!(decoded.num_columns(), 1); + assert_eq!(decoded.num_rows(), 3); + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("int column"); + assert_eq!(col, &Int32Array::from(vec![1, 2, 3])); Ok(()) }
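As a sanity check on the framing this round-trip exercises, a sketch of the expected bytes for the first row (an assumption-labeled illustration, given `CONFLUENT_MAGIC == [0x00]` and the `schema_id` of 42 used above):

```rust
// Each row is the 5-byte Confluent prefix followed by the Avro
// zig-zag varint of the Int32 value.
let mut expected_row: Vec<u8> = vec![0x00];           // magic byte
expected_row.extend_from_slice(&42u32.to_be_bytes()); // big-endian schema ID
expected_row.push(0x02);                              // zig-zag(1) == 2
assert_eq!(&encoded[..6], expected_row.as_slice());
```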