From 976730aa3ae410b10a80a57dd0aa593e5839a242 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 15 Sep 2025 00:37:59 -0500 Subject: [PATCH 01/22] Added decoder support for Union types and Union resolution. --- arrow-avro/src/codec.rs | 518 +++++++- arrow-avro/src/reader/mod.rs | 1004 +++++++++++++++- arrow-avro/src/reader/record.rs | 1521 +++++++++++++++++++++--- arrow-avro/src/schema.rs | 187 ++- arrow-avro/test/data/README.md | 61 +- arrow-avro/test/data/union_fields.avro | Bin 0 -> 3430 bytes 6 files changed, 3054 insertions(+), 237 deletions(-) create mode 100644 arrow-avro/test/data/union_fields.avro diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index cf0276f0a25d..c968c1a34cb1 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,20 +16,21 @@ // under the License. use crate::schema::{ - Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, PrimitiveType, - Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, + make_full_name, Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, + PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, AVRO_FIELD_DEFAULT_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, }; use arrow_schema::{ - ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, UnionFields, UnionMode, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; #[cfg(feature = "small_decimals")] use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use indexmap::IndexMap; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use strum_macros::AsRefStr; /// Contains information about how to resolve differences between a writer's and a reader's schema. 
#[derive(Debug, Clone, PartialEq)] @@ -42,6 +43,8 @@ pub(crate) enum ResolutionInfo { EnumMapping(EnumMapping), /// Provides resolution information for record fields. Record(ResolvedRecord), + /// Provides mapping and shape info for resolving unions. + Union(ResolvedUnion), } /// Represents a literal Avro value. @@ -92,8 +95,10 @@ pub struct ResolvedRecord { /// /// Schema resolution may require promoting a writer's data type to a reader's data type. /// For example, an `int` can be promoted to a `long`, `float`, or `double`. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum Promotion { + /// Direct read with no data type promotion. + Direct, /// Promotes an `int` to a `long`. IntToLong, /// Promotes an `int` to a `float`. @@ -112,6 +117,18 @@ pub(crate) enum Promotion { BytesToString, } +/// Information required to resolve a writer union against a reader union (or single type). +#[derive(Debug, Clone, PartialEq)] +pub struct ResolvedUnion { + /// For each writer branch index, the reader branch index and how to read it. + /// `None` means the writer branch doesn't resolve against the reader. + pub(crate) writer_to_reader: Arc<[Option<(usize, Promotion)>]>, + /// Whether the writer schema at this site is a union + pub(crate) writer_is_union: bool, + /// Whether the reader schema at this site is a union + pub(crate) reader_is_union: bool, +} + /// Holds the mapping information for resolving Avro enums. /// /// When resolving schemas, the writer's enum symbols must be mapped to the reader's symbols. 
@@ -267,6 +284,11 @@ impl AvroDataType { if default_json.is_null() { return match self.codec() { Codec::Null => Ok(AvroLiteral::Null), + Codec::Union(encodings, _, _) if !encodings.is_empty() + && matches!(encodings[0].codec(), Codec::Null) => + { + Ok(AvroLiteral::Null) + } _ if self.nullability() == Some(Nullability::NullFirst) => Ok(AvroLiteral::Null), _ => Err(ArrowError::SchemaError( "JSON null default is only valid for `null` type or for a union whose first branch is `null`" @@ -401,6 +423,14 @@ impl AvroDataType { )) } }, + Codec::Union(encodings, _, _) => { + if encodings.is_empty() { + return Err(ArrowError::SchemaError( + "Union with no branches cannot have a default".to_string(), + )); + } + encodings[0].parse_default_literal(default_json)? + } }; Ok(lit) } @@ -635,6 +665,8 @@ pub enum Codec { Map(Arc), /// Represents Avro duration logical type, maps to Arrow's Interval(IntervalUnit::MonthDayNano) data type Interval, + /// Represents Avro union type, maps to Arrow's Union data type + Union(Arc<[AvroDataType]>, UnionFields, UnionMode), } impl Codec { @@ -708,8 +740,42 @@ impl Codec { false, ) } + Self::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), + } + } + + /// Converts a string codec to use Utf8View if requested + /// + /// The conversion only happens if both: + /// 1. `use_utf8view` is true + /// 2. 
The codec is currently `Utf8` + /// + /// # Example + /// ``` + /// # use arrow_avro::codec::Codec; + /// let utf8_codec1 = Codec::Utf8; + /// let utf8_codec2 = Codec::Utf8; + /// + /// // Convert to Utf8View + /// let view_codec = utf8_codec1.with_utf8view(true); + /// assert!(matches!(view_codec, Codec::Utf8View)); + /// + /// // Don't convert if use_utf8view is false + /// let unchanged_codec = utf8_codec2.with_utf8view(false); + /// assert!(matches!(unchanged_codec, Codec::Utf8)); + /// ``` + pub fn with_utf8view(self, use_utf8view: bool) -> Self { + if use_utf8view && matches!(self, Self::Utf8) { + Self::Utf8View + } else { + self } } + + #[inline] + fn union_field_name(&self) -> String { + UnionFieldKind::from(self).as_ref().to_owned() + } } impl From for Codec { @@ -804,36 +870,79 @@ fn parse_decimal_attributes( Ok((precision, scale, size)) } -impl Codec { - /// Converts a string codec to use Utf8View if requested - /// - /// The conversion only happens if both: - /// 1. `use_utf8view` is true - /// 2. 
The codec is currently `Utf8` - /// - /// # Example - /// ``` - /// # use arrow_avro::codec::Codec; - /// let utf8_codec1 = Codec::Utf8; - /// let utf8_codec2 = Codec::Utf8; - /// - /// // Convert to Utf8View - /// let view_codec = utf8_codec1.with_utf8view(true); - /// assert!(matches!(view_codec, Codec::Utf8View)); - /// - /// // Don't convert if use_utf8view is false - /// let unchanged_codec = utf8_codec2.with_utf8view(false); - /// assert!(matches!(unchanged_codec, Codec::Utf8)); - /// ``` - pub fn with_utf8view(self, use_utf8view: bool) -> Self { - if use_utf8view && matches!(self, Self::Utf8) { - Self::Utf8View - } else { - self +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)] +#[strum(serialize_all = "snake_case")] +enum UnionFieldKind { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Date, + TimeMillis, + TimeMicros, + TimestampMillisUtc, + TimestampMillisLocal, + TimestampMicrosUtc, + TimestampMicrosLocal, + Duration, + Fixed, + Decimal, + Enum, + Array, + Record, + Map, + Uuid, + Union, +} + +impl From<&Codec> for UnionFieldKind { + fn from(c: &Codec) -> Self { + match c { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 => Self::Int, + Codec::Int64 => Self::Long, + Codec::Float32 => Self::Float, + Codec::Float64 => Self::Double, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Date32 => Self::Date, + Codec::TimeMillis => Self::TimeMillis, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(true) => Self::TimestampMillisUtc, + Codec::TimestampMillis(false) => Self::TimestampMillisLocal, + Codec::TimestampMicros(true) => Self::TimestampMicrosUtc, + Codec::TimestampMicros(false) => Self::TimestampMicrosLocal, + Codec::Interval => Self::Duration, + Codec::Fixed(_) => Self::Fixed, + Codec::Decimal(..) 
=> Self::Decimal, + Codec::Enum(_) => Self::Enum, + Codec::List(_) => Self::Array, + Codec::Struct(_) => Self::Record, + Codec::Map(_) => Self::Map, + Codec::Uuid => Self::Uuid, + Codec::Union(..) => Self::Union, } } } +#[inline] +fn build_union_fields(encodings: &[AvroDataType]) -> UnionFields { + let arrow_fields: Vec = encodings + .iter() + .map(|encoding| { + let name = encoding.codec().union_field_name(); + encoding.field_with_name(&name) + }) + .collect(); + let type_ids: Vec = (0..arrow_fields.len()).map(|i| i as i8).collect(); + UnionFields::new(type_ids, arrow_fields) +} + /// Resolves Avro type names to [`AvroDataType`] /// /// See @@ -915,6 +1024,76 @@ fn nullable_union_variants<'x, 'y>( } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum UnionBranchKey { + Named(String), + Primitive(PrimitiveType), + Array, + Map, +} + +fn branch_key_of<'a>(s: &Schema<'a>, enclosing_ns: Option<&'a str>) -> Option { + match s { + // Primitives + Schema::TypeName(TypeName::Primitive(p)) => Some(UnionBranchKey::Primitive(*p)), + Schema::Type(Type { + r#type: TypeName::Primitive(p), + .. + }) => Some(UnionBranchKey::Primitive(*p)), + // Named references + Schema::TypeName(TypeName::Ref(name)) => { + let (full, _) = make_full_name(name, None, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Type(Type { + r#type: TypeName::Ref(name), + .. 
+ }) => { + let (full, _) = make_full_name(name, None, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + // Complex non‑named + Schema::Complex(ComplexType::Array(_)) => Some(UnionBranchKey::Array), + Schema::Complex(ComplexType::Map(_)) => Some(UnionBranchKey::Map), + // Inline named definitions + Schema::Complex(ComplexType::Record(r)) => { + let (full, _) = make_full_name(r.name, r.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Complex(ComplexType::Enum(e)) => { + let (full, _) = make_full_name(e.name, e.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + Schema::Complex(ComplexType::Fixed(f)) => { + let (full, _) = make_full_name(f.name, f.namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) + } + // Unions are validated separately (and disallowed as immediate branches) + Schema::Union(_) => None, + } +} + +fn union_first_duplicate<'a>( + branches: &'a [Schema<'a>], + enclosing_ns: Option<&'a str>, +) -> Option { + let mut seen: HashSet = HashSet::with_capacity(branches.len()); + for b in branches { + if let Some(key) = branch_key_of(b, enclosing_ns) { + if !seen.insert(key.clone()) { + let msg = match key { + UnionBranchKey::Named(full) => format!("named type {full}"), + UnionBranchKey::Primitive(p) => format!("primitive {}", p.as_ref()), + UnionBranchKey::Array => "array".to_string(), + UnionBranchKey::Map => "map".to_string(), + }; + return Some(msg); + } + } + } + None +} + /// Resolves Avro type names to [`AvroDataType`] /// /// See @@ -969,7 +1148,6 @@ impl<'a> Maker<'a> { )), Schema::TypeName(TypeName::Ref(name)) => self.resolver.resolve(name, namespace), Schema::Union(f) => { - // Special case the common case of nullable primitives let null = f .iter() .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); @@ -977,7 +1155,7 @@ impl<'a> Maker<'a> { (true, Some(0)) => { let mut field = self.parse_type(&f[1], namespace)?; field.nullability = Some(Nullability::NullFirst); 
- Ok(field) + return Ok(field); } (true, Some(1)) => { if self.strict_mode { @@ -988,12 +1166,34 @@ impl<'a> Maker<'a> { } let mut field = self.parse_type(&f[0], namespace)?; field.nullability = Some(Nullability::NullSecond); - Ok(field) + return Ok(field); } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), + _ => {} + } + // Validate: unions may not immediately contain unions + if f.iter().any(|s| matches!(s, Schema::Union(_))) { + return Err(ArrowError::SchemaError( + "Avro unions may not immediately contain other unions".to_string(), + )); + } + // Validate: duplicates (named by full name; non-named by kind) + if let Some(dup) = union_first_duplicate(f, namespace) { + return Err(ArrowError::SchemaError(format!( + "Avro union contains duplicate branch type: {dup}" + ))); } + // Parse all branches + let children: Vec = f + .iter() + .map(|s| self.parse_type(s, namespace)) + .collect::>()?; + // Build Arrow layout once here + let union_fields = build_union_fields(&children); + Ok(AvroDataType::new( + Codec::Union(Arc::from(children), union_fields, UnionMode::Dense), + Default::default(), + None, + )) } Schema::Complex(c) => match c { ComplexType::Record(r) => { @@ -1149,6 +1349,67 @@ impl<'a> Maker<'a> { return self.resolve_primitives(write_primitive, read_primitive, reader_schema); } match (writer_schema, reader_schema) { + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => { + match ( + nullable_union_variants(writer_variants.as_slice()), + nullable_union_variants(reader_variants.as_slice()), + ) { + (Some((w_nb, w_nonnull)), Some((_r_nb, r_nonnull))) => { + let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?; + dt.nullability = Some(w_nb); + Ok(dt) + } + _ => self.resolve_unions( + writer_variants.as_slice(), + reader_variants.as_slice(), + namespace, + ), + } + } + (Schema::Union(writer_variants), reader_non_union) => { + let mut writer_to_reader: Vec> = + 
Vec::with_capacity(writer_variants.len()); + for writer in writer_variants { + match self.resolve_type(writer, reader_non_union, namespace) { + Ok(tmp) => writer_to_reader.push(Some((0usize, Self::coercion_from(&tmp)))), + Err(_) => writer_to_reader.push(None), + } + } + let mut dt = self.parse_type(reader_non_union, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: false, + })); + Ok(dt) + } + (writer_non_union, Schema::Union(reader_variants)) => { + let mut direct: Option<(usize, Promotion)> = None; + let mut promo: Option<(usize, Promotion)> = None; + for (reader_index, reader) in reader_variants.iter().enumerate() { + if let Ok(tmp) = self.resolve_type(writer_non_union, reader, namespace) { + let how = Self::coercion_from(&tmp); + if how == Promotion::Direct { + direct = Some((reader_index, how)); + break; // first exact match wins + } else if promo.is_none() { + promo = Some((reader_index, how)); + } + } + } + let (reader_index, promotion) = direct.or(promo).ok_or_else(|| { + ArrowError::SchemaError( + "Writer schema does not match any reader union branch".to_string(), + ) + })?; + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(vec![Some((reader_index, promotion))]), + writer_is_union: false, + reader_is_union: true, + })); + Ok(dt) + } ( Schema::Complex(ComplexType::Array(writer_array)), Schema::Complex(ComplexType::Array(reader_array)), @@ -1169,12 +1430,6 @@ impl<'a> Maker<'a> { Schema::Complex(ComplexType::Enum(writer_enum)), Schema::Complex(ComplexType::Enum(reader_enum)), ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace), - (Schema::Union(writer_variants), Schema::Union(reader_variants)) => self - .resolve_nullable_union( - writer_variants.as_slice(), - reader_variants.as_slice(), - namespace, - ), 
(Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace), (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace), _ => Err(ArrowError::NotYetImplemented( @@ -1183,6 +1438,56 @@ impl<'a> Maker<'a> { } } + #[inline] + fn coercion_from(dt: &AvroDataType) -> Promotion { + match dt.resolution.as_ref() { + Some(ResolutionInfo::Promotion(promotion)) => *promotion, + _ => Promotion::Direct, + } + } + + fn resolve_unions<'s>( + &mut self, + writer_variants: &'s [Schema<'a>], + reader_variants: &'s [Schema<'a>], + namespace: Option<&'a str>, + ) -> Result { + let reader_encodings: Vec = reader_variants + .iter() + .map(|reader_schema| self.parse_type(reader_schema, namespace)) + .collect::>()?; + let mut writer_to_reader: Vec> = + Vec::with_capacity(writer_variants.len()); + for writer in writer_variants { + let mut direct: Option<(usize, Promotion)> = None; + let mut promo: Option<(usize, Promotion)> = None; + for (reader_index, reader) in reader_variants.iter().enumerate() { + if let Ok(tmp) = self.resolve_type(writer, reader, namespace) { + let promotion = Self::coercion_from(&tmp); + if promotion == Promotion::Direct { + direct = Some((reader_index, promotion)); + break; + } else if promo.is_none() { + promo = Some((reader_index, promotion)); + } + } + } + writer_to_reader.push(direct.or(promo)); + } + let union_fields = build_union_fields(&reader_encodings); + let mut dt = AvroDataType::new( + Codec::Union(reader_encodings.into(), union_fields, UnionMode::Dense), + Default::default(), + None, + ); + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: true, + })); + Ok(dt) + } + fn resolve_array( &mut self, writer_array: &Array<'a>, @@ -1281,10 +1586,9 @@ impl<'a> Maker<'a> { nullable_union_variants(writer_variants), nullable_union_variants(reader_variants), ) { - (Some((_, write_nonnull)), Some((read_nb, 
read_nonnull))) => { + (Some((write_nb, write_nonnull)), Some((_read_nb, read_nonnull))) => { let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?; - // Adopt reader union null ordering - dt.nullability = Some(read_nb); + dt.nullability = Some(write_nb); Ok(dt) } _ => Err(ArrowError::NotYetImplemented( @@ -1557,6 +1861,24 @@ mod tests { .expect("promotion should resolve") } + fn mk_primitive(pt: PrimitiveType) -> Schema<'static> { + Schema::TypeName(TypeName::Primitive(pt)) + } + fn mk_union(branches: Vec>) -> Schema<'static> { + Schema::Union(branches) + } + + fn mk_record_named(name: &'static str) -> Schema<'static> { + Schema::Complex(ComplexType::Record(Record { + name, + namespace: None, + doc: None, + aliases: vec![], + fields: vec![], + attributes: Attributes::default(), + })) + } + #[test] fn test_date_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "date"); @@ -1842,7 +2164,7 @@ mod tests { } #[test] - fn test_promotion_within_nullable_union_keeps_reader_null_ordering() { + fn test_promotion_within_nullable_union_keeps_writer_null_ordering() { let writer = Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), @@ -1858,7 +2180,105 @@ mod tests { result.resolution, Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) ); - assert_eq!(result.nullability, Some(Nullability::NullSecond)); + assert_eq!(result.nullability, Some(Nullability::NullFirst)); + } + + #[test] + fn test_resolve_writer_union_to_reader_non_union_partial_coverage() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Long), + ]); + let reader = mk_primitive(PrimitiveType::Bytes); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Binary)); + let resolved = match dt.resolution { + 
Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(resolved.writer_is_union && !resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::StringToBytes)), None] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_prefers_direct_over_promotion() { + let writer = mk_primitive(PrimitiveType::Long); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::Double), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(!resolved.writer_is_union && resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::Direct))] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_uses_promotion_when_needed() { + let writer = mk_primitive(PrimitiveType::Int); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::String), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((1, Promotion::IntToLong))] + ); + } + + #[test] + fn test_resolve_both_nullable_unions_direct_match() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::String), + ]); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = 
maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Utf8)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert!(dt.resolution.is_none()); + } + + #[test] + fn test_resolve_both_nullable_unions_with_promotion() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Int), + ]); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Double), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Float64)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert_eq!( + dt.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); } #[test] diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 217366b63318..6bdaf2f98f7a 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -983,11 +983,13 @@ mod test { use arrow_array::types::{Int32Type, IntervalMonthDayNanoType}; use arrow_array::*; use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; - use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema}; + use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema, UnionFields, UnionMode, + }; use bytes::{Buf, BufMut, Bytes}; use futures::executor::block_on; use futures::{stream, Stream, StreamExt, TryStreamExt}; - use serde_json::Value; + use serde_json::{json, Value}; use std::collections::HashMap; use std::fs; use std::fs::File; @@ -2085,6 +2087,245 @@ mod test { assert!(batch.column(0).as_any().is::()); } + fn make_reader_schema_with_default_fields( + path: &str, + default_fields: Vec, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + root.as_object_mut() + .expect("schema is a JSON 
object") + .insert("fields".to_string(), Value::Array(default_fields)); + AvroSchema::new(root.to_string()) + } + + #[test] + fn test_schema_resolution_defaults_all_supported_types() { + let path = "test/data/skippable_types.avro"; + let duration_default = "\u{0000}".repeat(12); + let reader_schema = make_reader_schema_with_default_fields( + path, + vec![ + serde_json::json!({"name":"d_bool","type":"boolean","default":true}), + serde_json::json!({"name":"d_int","type":"int","default":42}), + serde_json::json!({"name":"d_long","type":"long","default":12345}), + serde_json::json!({"name":"d_float","type":"float","default":1.5}), + serde_json::json!({"name":"d_double","type":"double","default":2.25}), + serde_json::json!({"name":"d_bytes","type":"bytes","default":"XYZ"}), + serde_json::json!({"name":"d_string","type":"string","default":"hello"}), + serde_json::json!({"name":"d_date","type":{"type":"int","logicalType":"date"},"default":0}), + serde_json::json!({"name":"d_time_ms","type":{"type":"int","logicalType":"time-millis"},"default":1000}), + serde_json::json!({"name":"d_time_us","type":{"type":"long","logicalType":"time-micros"},"default":2000}), + serde_json::json!({"name":"d_ts_ms","type":{"type":"long","logicalType":"local-timestamp-millis"},"default":0}), + serde_json::json!({"name":"d_ts_us","type":{"type":"long","logicalType":"local-timestamp-micros"},"default":0}), + serde_json::json!({"name":"d_decimal","type":{"type":"bytes","logicalType":"decimal","precision":10,"scale":2},"default":""}), + serde_json::json!({"name":"d_fixed","type":{"type":"fixed","name":"F4","size":4},"default":"ABCD"}), + serde_json::json!({"name":"d_enum","type":{"type":"enum","name":"E","symbols":["A","B","C"]},"default":"A"}), + serde_json::json!({"name":"d_duration","type":{"type":"fixed","name":"Dur","size":12,"logicalType":"duration"},"default":duration_default}), + 
serde_json::json!({"name":"d_uuid","type":{"type":"string","logicalType":"uuid"},"default":"00000000-0000-0000-0000-000000000000"}), + serde_json::json!({"name":"d_array","type":{"type":"array","items":"int"},"default":[1,2,3]}), + serde_json::json!({"name":"d_map","type":{"type":"map","values":"long"},"default":{"a":1,"b":2}}), + serde_json::json!({"name":"d_record","type":{ + "type":"record","name":"DefaultRec","fields":[ + {"name":"x","type":"int"}, + {"name":"y","type":["null","string"],"default":null} + ] + },"default":{"x":7}}), + serde_json::json!({"name":"d_nullable_null","type":["null","int"],"default":null}), + serde_json::json!({"name":"d_nullable_value","type":["int","null"],"default":123}), + ], + ); + let actual = read_alltypes_with_reader_schema(path, reader_schema); + let num_rows = actual.num_rows(); + assert!(num_rows > 0, "skippable_types.avro should contain rows"); + assert_eq!( + actual.num_columns(), + 22, + "expected exactly our defaulted fields" + ); + let mut arrays: Vec> = Vec::with_capacity(22); + arrays.push(Arc::new(BooleanArray::from_iter(std::iter::repeat_n( + Some(true), + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 42, num_rows, + )))); + arrays.push(Arc::new(Int64Array::from_iter_values(std::iter::repeat_n( + 12345, num_rows, + )))); + arrays.push(Arc::new(Float32Array::from_iter_values( + std::iter::repeat_n(1.5f32, num_rows), + ))); + arrays.push(Arc::new(Float64Array::from_iter_values( + std::iter::repeat_n(2.25f64, num_rows), + ))); + arrays.push(Arc::new(BinaryArray::from_iter_values( + std::iter::repeat_n(b"XYZ".as_ref(), num_rows), + ))); + arrays.push(Arc::new(StringArray::from_iter_values( + std::iter::repeat_n("hello", num_rows), + ))); + arrays.push(Arc::new(Date32Array::from_iter_values( + std::iter::repeat_n(0, num_rows), + ))); + arrays.push(Arc::new(Time32MillisecondArray::from_iter_values( + std::iter::repeat_n(1_000, num_rows), + ))); + 
arrays.push(Arc::new(Time64MicrosecondArray::from_iter_values( + std::iter::repeat_n(2_000i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMillisecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMicrosecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + #[cfg(feature = "small_decimals")] + let decimal = Decimal64Array::from_iter_values(std::iter::repeat_n(0i64, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + #[cfg(not(feature = "small_decimals"))] + let decimal = Decimal128Array::from_iter_values(std::iter::repeat_n(0i128, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + arrays.push(Arc::new(decimal)); + let fixed_iter = std::iter::repeat_n(Some(*b"ABCD"), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(fixed_iter, 4).unwrap(), + )); + let enum_keys = Int32Array::from_iter_values(std::iter::repeat_n(0, num_rows)); + let enum_values = StringArray::from_iter_values(["A", "B", "C"]); + let enum_arr = + DictionaryArray::::try_new(enum_keys, Arc::new(enum_values)).unwrap(); + arrays.push(Arc::new(enum_arr)); + let duration_values = std::iter::repeat_n( + Some(IntervalMonthDayNanoType::make_value(0, 0, 0)), + num_rows, + ); + let duration_arr: IntervalMonthDayNanoArray = duration_values.collect(); + arrays.push(Arc::new(duration_arr)); + let uuid_bytes = [0u8; 16]; + let uuid_iter = std::iter::repeat_n(Some(uuid_bytes), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(uuid_iter, 16).unwrap(), + )); + let item_field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int32, + false, + )); + let mut list_builder = ListBuilder::new(Int32Builder::new()).with_field(item_field); + for _ in 0..num_rows { + list_builder.values().append_value(1); + list_builder.values().append_value(2); + list_builder.values().append_value(3); + list_builder.append(true); + } + 
arrays.push(Arc::new(list_builder.finish())); + let values_field = Arc::new(Field::new("value", DataType::Int64, false)); + let mut map_builder = MapBuilder::new( + Some(builder::MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + StringBuilder::new(), + Int64Builder::new(), + ) + .with_values_field(values_field); + for _ in 0..num_rows { + let (keys, vals) = map_builder.entries(); + keys.append_value("a"); + vals.append_value(1); + keys.append_value("b"); + vals.append_value(2); + map_builder.append(true).unwrap(); + } + arrays.push(Arc::new(map_builder.finish())); + let rec_fields: Fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ]); + let mut sb = StructBuilder::new( + rec_fields.clone(), + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + for _ in 0..num_rows { + sb.field_builder::(0).unwrap().append_value(7); + sb.field_builder::(1).unwrap().append_null(); + sb.append(true); + } + arrays.push(Arc::new(sb.finish())); + arrays.push(Arc::new(Int32Array::from_iter(std::iter::repeat_n( + None::, + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 123, num_rows, + )))); + let expected = RecordBatch::try_new(actual.schema(), arrays).unwrap(); + assert_eq!( + actual, expected, + "defaults should materialize correctly for all fields" + ); + } + + #[test] + fn test_schema_resolution_default_enum_invalid_symbol_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_enum", + "type":{"type":"enum","name":"E","symbols":["A","B","C"]}, + "default":"Z" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected enum default validation to 
fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("enum") + && (lower_msg.contains("symbol") || lower_msg.contains("default")), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_schema_resolution_default_fixed_size_mismatch_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_fixed", + "type":{"type":"fixed","name":"F","size":4}, + "default":"ABC" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected fixed default validation to fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("fixed") + && (lower_msg.contains("size") + || lower_msg.contains("length") + || lower_msg.contains("does not match")), + "unexpected error: {msg}" + ); + } + #[test] fn test_alltypes_skip_writer_fields_keep_double_only() { let file = arrow_test_data("avro/alltypes_plain.avro"); @@ -2186,6 +2427,763 @@ mod test { } } + #[test] + fn test_union_fields_avro_nullable_and_general_unions() { + let path = "test/data/union_fields.avro"; + let batch = read_file(path, 1024, false); + let schema = batch.schema(); + let idx = schema.index_of("nullable_int_nullfirst").unwrap(); + let a = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("nullable_int_nullfirst should be Int32"); + assert_eq!(a.len(), 4); + assert!(a.is_null(0)); + assert_eq!(a.value(1), 42); + assert!(a.is_null(2)); + assert_eq!(a.value(3), 0); + let idx = schema.index_of("nullable_string_nullsecond").unwrap(); + let s = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("nullable_string_nullsecond should be Utf8"); + assert_eq!(s.len(), 4); + assert_eq!(s.value(0), "s1"); + assert!(s.is_null(1)); + assert_eq!(s.value(2), "s3"); + 
assert!(s.is_valid(3)); // empty string, not null + assert_eq!(s.value(3), ""); + let idx = schema.index_of("union_prim").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_prim should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, mode) => { + assert!(matches!(mode, UnionMode::Dense), "expect dense unions"); + fields + } + other => panic!("expected Union, got {other:?}"), + }; + let tid_by_name = |name: &str| -> i8 { + for (tid, f) in fields.iter() { + if f.name() == name { + return tid; + } + } + panic!("union child '{name}' not found"); + }; + let expected_type_ids = vec![ + tid_by_name("long"), + tid_by_name("int"), + tid_by_name("float"), + tid_by_name("double"), + ]; + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, expected_type_ids, + "branch selection for union_prim rows" + ); + let longs = u + .child(tid_by_name("long")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(longs.len(), 1); + let ints = u + .child(tid_by_name("int")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(ints.len(), 1); + let floats = u + .child(tid_by_name("float")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(floats.len(), 1); + let doubles = u + .child(tid_by_name("double")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(doubles.len(), 1); + let idx = schema.index_of("union_bytes_vs_string").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_bytes_vs_string should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let tid_by_name = |name: &str| -> i8 { + for (tid, f) in fields.iter() { + if f.name() == name { + return tid; + } + } + panic!("union child '{name}' not found"); + }; + let tid_bytes = tid_by_name("bytes"); + let tid_string = tid_by_name("string"); + let type_ids: Vec = 
u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_bytes, tid_string, tid_string, tid_bytes], + "branch selection for bytes/string union" + ); + let s_child = u + .child(tid_string) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(s_child.len(), 2); + assert_eq!(s_child.value(0), "hello"); + assert_eq!(s_child.value(1), "world"); + let b_child = u + .child(tid_bytes) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_child.len(), 2); + assert_eq!(b_child.value(0), &[0x00, 0xFF, 0x7F]); + assert_eq!(b_child.value(1), b""); // previously: &[] + let idx = schema.index_of("union_enum_records_array_map").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_enum_records_array_map should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut tid_enum: Option = None; + let mut tid_rec_a: Option = None; + let mut tid_rec_b: Option = None; + let mut tid_array: Option = None; + let mut tid_map: Option = None; + for (tid, f) in fields.iter() { + match f.data_type() { + DataType::Dictionary(_, _) => tid_enum = Some(tid), + DataType::Struct(childs) => { + if childs.len() == 2 && childs[0].name() == "a" && childs[1].name() == "b" { + tid_rec_a = Some(tid); + } else if childs.len() == 2 + && childs[0].name() == "x" + && childs[1].name() == "y" + { + tid_rec_b = Some(tid); + } + } + DataType::List(_) => tid_array = Some(tid), + DataType::Map(_, _) => tid_map = Some(tid), + _ => {} + } + } + let (tid_enum, tid_rec_a, tid_rec_b, tid_array) = ( + tid_enum.expect("enum child"), + tid_rec_a.expect("RecA child"), + tid_rec_b.expect("RecB child"), + tid_array.expect("array child"), + ); + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_enum, tid_rec_a, tid_rec_b, tid_array], + "branch selection for complex union" + ); + let dict = u + .child(tid_enum) + 
.as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict.len(), 1); + assert!(dict.is_valid(0)); + let rec_a = u + .child(tid_rec_a) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(rec_a.len(), 1); + let a_val = rec_a + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_val.value(0), 7); + let b_val = rec_a + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_val.value(0), "x"); + // RecB row: {"x": 123456789, "y": b"\xFF\x00"} + let rec_b = u + .child(tid_rec_b) + .as_any() + .downcast_ref::() + .unwrap(); + let x_val = rec_b + .column_by_name("x") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(x_val.value(0), 123_456_789_i64); + let y_val = rec_b + .column_by_name("y") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(y_val.value(0), &[0xFF, 0x00]); + let arr = u + .child(tid_array) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.len(), 1); + let first_values = arr.value(0); + let longs = first_values.as_any().downcast_ref::().unwrap(); + assert_eq!(longs.len(), 3); + assert_eq!(longs.value(0), 1); + assert_eq!(longs.value(1), 2); + assert_eq!(longs.value(2), 3); + let idx = schema.index_of("union_date_or_fixed4").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_date_or_fixed4 should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut tid_date: Option = None; + let mut tid_fixed: Option = None; + for (tid, f) in fields.iter() { + match f.data_type() { + DataType::Date32 => tid_date = Some(tid), + DataType::FixedSizeBinary(4) => tid_fixed = Some(tid), + _ => {} + } + } + let (tid_date, tid_fixed) = (tid_date.expect("date"), tid_fixed.expect("fixed(4)")); + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_date, tid_fixed, 
tid_date, tid_fixed], + "branch selection for date/fixed4 union" + ); + let dates = u + .child(tid_date) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dates.len(), 2); + assert_eq!(dates.value(0), 19_000); // ~2022‑01‑15 + assert_eq!(dates.value(1), 0); // epoch + let fixed = u + .child(tid_fixed) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(fixed.len(), 2); + assert_eq!(fixed.value(0), b"ABCD"); + assert_eq!(fixed.value(1), &[0x00, 0x11, 0x22, 0x33]); + } + + #[test] + fn test_union_schema_resolution_all_type_combinations() { + let path = "test/data/union_fields.avro"; + let baseline = read_file(path, 1024, false); + let baseline_schema = baseline.schema(); + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + fn is_named_type(obj: &Value, ty: &str, nm: &str) -> bool { + obj.get("type").and_then(|v| v.as_str()) == Some(ty) + && obj.get("name").and_then(|v| v.as_str()) == Some(nm) + } + fn is_logical(obj: &Value, prim: &str, lt: &str) -> bool { + obj.get("type").and_then(|v| v.as_str()) == Some(prim) + && obj.get("logicalType").and_then(|v| v.as_str()) == Some(lt) + } + fn find_first(arr: &[Value], pred: impl Fn(&Value) -> bool) -> Option { + arr.iter().find(|v| pred(v)).cloned() + } + fn prim(s: &str) -> Value { + Value::String(s.to_string()) + } + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + match name { + // Flip null ordering – should not affect values + "nullable_int_nullfirst" => { + f["type"] = json!(["int", "null"]); + } + "nullable_string_nullsecond" => { + f["type"] = json!(["null", "string"]); + } + "union_prim" => { + let orig = f["type"].as_array().unwrap().clone(); + let long = prim("long"); + let double = prim("double"); + let string = prim("string"); + let bytes = prim("bytes"); + let 
boolean = prim("boolean"); + assert!(orig.contains(&long)); + assert!(orig.contains(&double)); + assert!(orig.contains(&string)); + assert!(orig.contains(&bytes)); + assert!(orig.contains(&boolean)); + f["type"] = json!([long, double, string, bytes, boolean]); + } + "union_bytes_vs_string" => { + f["type"] = json!(["string", "bytes"]); + } + "union_fixed_dur_decfix" => { + let orig = f["type"].as_array().unwrap().clone(); + let fx8 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx8")).unwrap(); + let dur12 = find_first(&orig, |o| is_named_type(o, "fixed", "Dur12")).unwrap(); + let decfix16 = + find_first(&orig, |o| is_named_type(o, "fixed", "DecFix16")).unwrap(); + f["type"] = json!([decfix16, dur12, fx8]); + } + "union_enum_records_array_map" => { + let orig = f["type"].as_array().unwrap().clone(); + let enum_color = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("enum") + }) + .unwrap(); + let rec_a = find_first(&orig, |o| is_named_type(o, "record", "RecA")).unwrap(); + let rec_b = find_first(&orig, |o| is_named_type(o, "record", "RecB")).unwrap(); + let arr = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("array") + }) + .unwrap(); + let map = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("map") + }) + .unwrap(); + f["type"] = json!([arr, map, rec_b, rec_a, enum_color]); + } + "union_date_or_fixed4" => { + let orig = f["type"].as_array().unwrap().clone(); + let date = find_first(&orig, |o| is_logical(o, "int", "date")).unwrap(); + let fx4 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx4")).unwrap(); + f["type"] = json!([fx4, date]); + } + "union_time_millis_or_enum" => { + let orig = f["type"].as_array().unwrap().clone(); + let time_ms = + find_first(&orig, |o| is_logical(o, "int", "time-millis")).unwrap(); + let en = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("enum") + }) + .unwrap(); + f["type"] = json!([en, time_ms]); + } + 
"union_time_micros_or_string" => { + let orig = f["type"].as_array().unwrap().clone(); + let time_us = + find_first(&orig, |o| is_logical(o, "long", "time-micros")).unwrap(); + f["type"] = json!(["string", time_us]); + } + "union_ts_millis_utc_or_array" => { + let orig = f["type"].as_array().unwrap().clone(); + let ts_ms = + find_first(&orig, |o| is_logical(o, "long", "timestamp-millis")).unwrap(); + let arr = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("array") + }) + .unwrap(); + f["type"] = json!([arr, ts_ms]); + } + "union_ts_micros_local_or_bytes" => { + let orig = f["type"].as_array().unwrap().clone(); + let lts_us = + find_first(&orig, |o| is_logical(o, "long", "local-timestamp-micros")) + .unwrap(); + f["type"] = json!(["bytes", lts_us]); + } + "union_uuid_or_fixed10" => { + let orig = f["type"].as_array().unwrap().clone(); + let uuid = find_first(&orig, |o| is_logical(o, "string", "uuid")).unwrap(); + let fx10 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx10")).unwrap(); + f["type"] = json!([fx10, uuid]); + } + "union_dec_bytes_or_dec_fixed" => { + let orig = f["type"].as_array().unwrap().clone(); + let dec_bytes = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("bytes") + && o.get("logicalType").and_then(|v| v.as_str()) == Some("decimal") + }) + .unwrap(); + let dec_fix = find_first(&orig, |o| { + is_named_type(o, "fixed", "DecFix20") + && o.get("logicalType").and_then(|v| v.as_str()) == Some("decimal") + }) + .unwrap(); + f["type"] = json!([dec_fix, dec_bytes]); + } + "union_null_bytes_string" => { + f["type"] = json!(["bytes", "string", "null"]); + } + "array_of_union" => { + let obj = f + .get_mut("type") + .expect("array type") + .as_object_mut() + .unwrap(); + obj.insert("items".to_string(), json!(["string", "long"])); + } + "map_of_union" => { + let obj = f + .get_mut("type") + .expect("map type") + .as_object_mut() + .unwrap(); + obj.insert("values".to_string(), json!(["double", 
"null"])); + } + "record_with_union_field" => { + let rec = f + .get_mut("type") + .expect("record type") + .as_object_mut() + .unwrap(); + let rec_fields = rec.get_mut("fields").unwrap().as_array_mut().unwrap(); + let mut found = false; + for rf in rec_fields.iter_mut() { + if rf.get("name").and_then(|v| v.as_str()) == Some("u") { + rf["type"] = json!(["string", "long"]); // rely on int→long promotion + found = true; + break; + } + } + assert!(found, "field 'u' expected in HasUnion"); + } + "union_ts_micros_utc_or_map" => { + let orig = f["type"].as_array().unwrap().clone(); + let ts_us = + find_first(&orig, |o| is_logical(o, "long", "timestamp-micros")).unwrap(); + let map = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("map") + }) + .unwrap(); + f["type"] = json!([map, ts_us]); + } + "union_ts_millis_local_or_string" => { + let orig = f["type"].as_array().unwrap().clone(); + let lts_ms = + find_first(&orig, |o| is_logical(o, "long", "local-timestamp-millis")) + .unwrap(); + f["type"] = json!(["string", lts_ms]); + } + "union_bool_or_string" => { + f["type"] = json!(["string", "boolean"]); + } + _ => {} + } + } + let reader_schema = AvroSchema::new(root.to_string()); + let resolved = read_alltypes_with_reader_schema(path, reader_schema); + + fn branch_token(dt: &DataType) -> String { + match dt { + DataType::Null => "null".into(), + DataType::Boolean => "boolean".into(), + DataType::Int32 => "int".into(), + DataType::Int64 => "long".into(), + DataType::Float32 => "float".into(), + DataType::Float64 => "double".into(), + DataType::Binary => "bytes".into(), + DataType::Utf8 => "string".into(), + DataType::Date32 => "date".into(), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => "time-millis".into(), + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => "time-micros".into(), + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => if tz.is_some() { + "timestamp-millis" + } else { + "local-timestamp-millis" + } + 
.into(), + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => if tz.is_some() { + "timestamp-micros" + } else { + "local-timestamp-micros" + } + .into(), + DataType::Interval(IntervalUnit::MonthDayNano) => "duration".into(), + DataType::FixedSizeBinary(n) => format!("fixed{n}"), + DataType::Dictionary(_, _) => "enum".into(), + DataType::Decimal128(p, s) => format!("decimal({p},{s})"), + DataType::Decimal256(p, s) => format!("decimal({p},{s})"), + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => format!("decimal({p},{s})"), + DataType::Struct(fields) => { + if fields.len() == 2 && fields[0].name() == "a" && fields[1].name() == "b" { + "record:RecA".into() + } else if fields.len() == 2 + && fields[0].name() == "x" + && fields[1].name() == "y" + { + "record:RecB".into() + } else { + "record".into() + } + } + DataType::List(_) => "array".into(), + DataType::Map(_, _) => "map".into(), + other => format!("{other:?}"), + } + } + + fn union_tokens(u: &UnionArray) -> (Vec, HashMap) { + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut dict: HashMap = HashMap::with_capacity(fields.len()); + for (tid, f) in fields.iter() { + dict.insert(tid, branch_token(f.data_type())); + } + let ids: Vec = u.type_ids().iter().copied().collect(); + (ids, dict) + } + + fn expected_token(field_name: &str, writer_token: &str) -> String { + match field_name { + "union_prim" => match writer_token { + "int" => "long".into(), + "float" => "double".into(), + other => other.into(), + }, + "record_with_union_field.u" => match writer_token { + "int" => "long".into(), + other => other.into(), + }, + _ => writer_token.into(), + } + } + + fn get_union<'a>( + rb: &'a RecordBatch, + schema: arrow_schema::SchemaRef, + fname: &str, + ) -> &'a UnionArray { + let idx = schema.index_of(fname).unwrap(); + rb.column(idx) + .as_any() + .downcast_ref::() + .unwrap_or_else(|| panic!("{fname} 
should be a Union")) + } + + fn assert_union_equivalent(field_name: &str, u_writer: &UnionArray, u_reader: &UnionArray) { + let (ids_w, dict_w) = union_tokens(u_writer); + let (ids_r, dict_r) = union_tokens(u_reader); + assert_eq!( + ids_w.len(), + ids_r.len(), + "{field_name}: row count mismatch between baseline and resolved" + ); + for (i, (id_w, id_r)) in ids_w.iter().zip(ids_r.iter()).enumerate() { + let w_tok = dict_w.get(id_w).unwrap(); + let want = expected_token(field_name, w_tok); + let got = dict_r.get(id_r).unwrap(); + assert_eq!( + got, &want, + "{field_name}: row {i} resolved to wrong union branch (writer={w_tok}, expected={want}, got={got})" + ); + } + } + + for (fname, dt) in [ + ("nullable_int_nullfirst", DataType::Int32), + ("nullable_string_nullsecond", DataType::Utf8), + ] { + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let col_b = baseline.column(idx_b); + let col_r = resolved.column(idx_r); + assert_eq!( + col_b.data_type(), + &dt, + "baseline {fname} should decode as non-union with nullability" + ); + assert_eq!( + col_b.as_ref(), + col_r.as_ref(), + "{fname}: values must be identical regardless of null-branch order" + ); + } + let union_fields = [ + "union_prim", + "union_bytes_vs_string", + "union_fixed_dur_decfix", + "union_enum_records_array_map", + "union_date_or_fixed4", + "union_time_millis_or_enum", + "union_time_micros_or_string", + "union_ts_millis_utc_or_array", + "union_ts_micros_local_or_bytes", + "union_uuid_or_fixed10", + "union_dec_bytes_or_dec_fixed", + "union_null_bytes_string", + "union_ts_micros_utc_or_map", + "union_ts_millis_local_or_string", + "union_bool_or_string", + ]; + for fname in union_fields { + let u_b = get_union(&baseline, baseline_schema.clone(), fname); + let u_r = get_union(&resolved, resolved.schema(), fname); + assert_union_equivalent(fname, u_b, u_r); + } + { + let fname = "array_of_union"; + let idx_b = 
baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let arr_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("array_of_union should be a List"); + let arr_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("array_of_union should be a List"); + assert_eq!( + arr_b.value_offsets(), + arr_r.value_offsets(), + "{fname}: list offsets changed after resolution" + ); + let u_b = arr_b + .values() + .as_any() + .downcast_ref::() + .expect("array items should be Union"); + let u_r = arr_r + .values() + .as_any() + .downcast_ref::() + .expect("array items should be Union"); + let (ids_b, dict_b) = union_tokens(u_b); + let (ids_r, dict_r) = union_tokens(u_r); + assert_eq!(ids_b.len(), ids_r.len(), "{fname}: values length mismatch"); + for (i, (id_b, id_r)) in ids_b.iter().zip(ids_r.iter()).enumerate() { + let w_tok = dict_b.get(id_b).unwrap(); + let got = dict_r.get(id_r).unwrap(); + assert_eq!( + got, w_tok, + "{fname}: value {i} resolved to wrong branch (writer={w_tok}, got={got})" + ); + } + } + { + let fname = "map_of_union"; + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let map_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("map_of_union should be a Map"); + let map_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("map_of_union should be a Map"); + assert_eq!( + map_b.value_offsets(), + map_r.value_offsets(), + "{fname}: map value offsets changed after resolution" + ); + let ent_b = map_b.entries(); + let ent_r = map_r.entries(); + let val_b_any = ent_b.column(1).as_ref(); + let val_r_any = ent_r.column(1).as_ref(); + let b_union = val_b_any.as_any().downcast_ref::(); + let r_union = val_r_any.as_any().downcast_ref::(); + if let (Some(u_b), Some(u_r)) = (b_union, r_union) { + assert_union_equivalent(fname, u_b, u_r); + } else { + assert_eq!( + 
val_b_any.data_type(), + val_r_any.data_type(), + "{fname}: value data types differ after resolution" + ); + assert_eq!( + val_b_any, val_r_any, + "{fname}: value arrays differ after resolution (nullable value column case)" + ); + let value_nullable = |m: &MapArray| -> bool { + match m.data_type() { + DataType::Map(entries_field, _sorted) => match entries_field.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.len(), 2, "entries struct must have 2 fields"); + assert_eq!(fields[0].name(), "key"); + assert_eq!(fields[1].name(), "value"); + fields[1].is_nullable() + } + other => panic!("Map entries field must be Struct, got {other:?}"), + }, + other => panic!("expected Map data type, got {other:?}"), + } + }; + assert!( + value_nullable(map_b), + "{fname}: baseline Map value field should be nullable per Arrow spec" + ); + assert!( + value_nullable(map_r), + "{fname}: resolved Map value field should be nullable per Arrow spec" + ); + } + } + { + let fname = "record_with_union_field"; + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let rec_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("record_with_union_field should be a Struct"); + let rec_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("record_with_union_field should be a Struct"); + let u_b = rec_b + .column_by_name("u") + .unwrap() + .as_any() + .downcast_ref::() + .expect("field 'u' should be Union (baseline)"); + let u_r = rec_r + .column_by_name("u") + .unwrap() + .as_any() + .downcast_ref::() + .expect("field 'u' should be Union (resolved)"); + assert_union_equivalent("record_with_union_field.u", u_b, u_r); + } + } + #[test] fn test_read_zero_byte_avro_file() { let batch = read_file("test/data/zero_byte.avro", 3, false); @@ -2538,6 +3536,7 @@ mod test { let values_i128: Vec = (1..=24).map(|n| (n as i128) * pow10).collect(); let build_expected = |dt: &DataType, values: &[i128]| -> 
ArrayRef { match *dt { + #[cfg(feature = "small_decimals")] DataType::Decimal32(p, s) => { let it = values.iter().map(|&v| v as i32); Arc::new( @@ -2546,6 +3545,7 @@ mod test { .unwrap(), ) } + #[cfg(feature = "small_decimals")] DataType::Decimal64(p, s) => { let it = values.iter().map(|&v| v as i64); Arc::new( diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 48eb601024b5..188b9d486ff5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -15,27 +15,27 @@ // specific language governing permissions and limitations // under the License. -use crate::codec::{AvroDataType, Codec, Promotion, ResolutionInfo}; +use crate::codec::{ + AvroDataType, AvroField, AvroLiteral, Codec, Promotion, ResolutionInfo, ResolvedRecord, + ResolvedUnion, +}; use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::schema::*; +use crate::schema::Nullability; use arrow_array::builder::{ - ArrayBuilder, Decimal128Builder, Decimal256Builder, Decimal32Builder, Decimal64Builder, - IntervalMonthDayNanoBuilder, PrimitiveBuilder, + Decimal128Builder, Decimal256Builder, Decimal32Builder, Decimal64Builder, + IntervalMonthDayNanoBuilder, }; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, - Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; #[cfg(feature = "small_decimals")] use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use std::cmp::Ordering; -use std::collections::HashMap; -use std::io::Read; use std::sync::Arc; use uuid::Uuid; @@ -60,6 +60,69 @@ macro_rules! 
flush_decimal { }}; } +/// Macro to append a default decimal value from two's-complement big-endian bytes +/// into the corresponding decimal builder, with compile-time constructed error text. +macro_rules! append_decimal_default { + ($lit:expr, $builder:expr, $N:literal, $Int:ty, $name:literal) => {{ + match $lit { + AvroLiteral::Bytes(b) => { + let ext = sign_cast_to::<$N>(b)?; + let val = <$Int>::from_be_bytes(ext); + $builder.append_value(val); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + concat!( + "Default for ", + $name, + " must be bytes (two's-complement big-endian)" + ) + .to_string(), + )), + } + }}; +} + +macro_rules! flush_union { + ($fields:expr, $type_ids:expr, $offsets:expr, $encodings:expr) => {{ + let encoding_arrays = $encodings + .iter_mut() + .map(|d| d.flush(None)) + .collect::, _>>()?; + let type_ids_buf: ScalarBuffer = flush_values($type_ids).into_iter().collect(); + let offsets_buf: ScalarBuffer = flush_values($offsets).into_iter().collect(); + let arr = UnionArray::try_new( + $fields.clone(), + type_ids_buf, + Some(offsets_buf), + encoding_arrays, + ) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(arr) + }}; +} + +macro_rules! 
get_writer_union_action { + ($buf:expr, $union_resolution:expr) => {{ + let branch = $buf.get_long()?; + if branch < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {branch}" + ))); + } + let idx = branch as usize; + let dispatch = match $union_resolution.dispatch.as_deref() { + Some(d) => d, + None => { + return Err(ArrowError::SchemaError( + "dispatch table missing for writer=union".to_string(), + )); + } + }; + (idx, *dispatch.get(idx).unwrap_or(&BranchDispatch::NoMatch)) + }}; +} + #[derive(Debug)] pub(crate) struct RecordDecoderBuilder<'a> { data_type: &'a AvroDataType, @@ -91,15 +154,7 @@ pub(crate) struct RecordDecoder { schema: SchemaRef, fields: Vec, use_utf8view: bool, - resolved: Option, -} - -#[derive(Debug)] -struct ResolvedRuntime { - /// writer field index -> reader field index (or None if writer-only) - writer_to_reader: Arc<[Option]>, - /// per-writer-field skipper (Some only when writer-only) - skip_decoders: Vec>, + projector: Option, } impl RecordDecoder { @@ -131,29 +186,25 @@ impl RecordDecoder { ) -> Result { match data_type.codec() { Codec::Struct(reader_fields) => { - // Build Arrow schema fields and per-child decoders let mut arrow_fields = Vec::with_capacity(reader_fields.len()); let mut encodings = Vec::with_capacity(reader_fields.len()); for avro_field in reader_fields.iter() { arrow_fields.push(avro_field.field()); encodings.push(Decoder::try_new(avro_field.data_type())?); } - // If this record carries resolution metadata, prepare top-level runtime helpers - let resolved = match data_type.resolution.as_ref() { - Some(ResolutionInfo::Record(rec)) => { - let skip_decoders = build_skip_decoders(&rec.skip_fields)?; - Some(ResolvedRuntime { - writer_to_reader: rec.writer_to_reader.clone(), - skip_decoders, - }) - } + let projector = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Record(rec)) => Some( + ProjectorBuilder::try_new(rec) + .with_reader_fields(reader_fields) + .build()?, + ), _ => 
None, }; Ok(Self { schema: Arc::new(ArrowSchema::new(arrow_fields)), fields: encodings, use_utf8view, - resolved, + projector, }) } other => Err(ArrowError::ParseError(format!( @@ -170,17 +221,10 @@ impl RecordDecoder { /// Decode `count` records from `buf` pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); - match self.resolved.as_mut() { - Some(runtime) => { - // Top-level resolved record: read writer fields in writer order, - // project into reader fields, and skip writer-only fields + match self.projector.as_mut() { + Some(proj) => { for _ in 0..count { - decode_with_resolution( - &mut cursor, - &mut self.fields, - &runtime.writer_to_reader, - &mut runtime.skip_decoders, - )?; + proj.project_record(&mut cursor, &mut self.fields)?; } } None => { @@ -205,24 +249,152 @@ impl RecordDecoder { } } -fn decode_with_resolution( - buf: &mut AvroCursor<'_>, - encodings: &mut [Decoder], - writer_to_reader: &[Option], - skippers: &mut [Option], -) -> Result<(), ArrowError> { - for (w_idx, (target, skipper_opt)) in writer_to_reader.iter().zip(skippers).enumerate() { - match (*target, skipper_opt.as_mut()) { - (Some(r_idx), _) => encodings[r_idx].decode(buf)?, - (None, Some(sk)) => sk.skip(buf)?, - (None, None) => { - return Err(ArrowError::SchemaError(format!( - "No skipper available for writer-only field at index {w_idx}", - ))); +#[derive(Debug)] +struct EnumResolution { + mapping: Arc<[i32]>, + default_index: i32, +} + +#[derive(Debug, Clone, Copy)] +enum BranchDispatch { + NoMatch, + ToReader { + reader_idx: usize, + promotion: Promotion, + }, +} + +#[derive(Debug)] +struct UnionResolution { + dispatch: Option>, + kind: UnionResolvedKind, +} + +#[derive(Debug)] +enum UnionResolvedKind { + Both { + reader_type_codes: Arc<[i8]>, + }, + ToSingle { + target: Box, + }, + FromSingle { + reader_type_codes: Arc<[i8]>, + target_reader_index: usize, + promotion: Promotion, + }, +} + +#[derive(Debug, Default)] +struct 
UnionResolutionBuilder { + fields: Option, + resolved: Option, +} + +impl UnionResolutionBuilder { + #[inline] + fn new() -> Self { + Self { + fields: None, + resolved: None, + } + } + + #[inline] + fn with_fields(mut self, fields: UnionFields) -> Self { + self.fields = Some(fields); + self + } + + #[inline] + fn with_resolved_union(mut self, resolved_union: &ResolvedUnion) -> Self { + self.resolved = Some(resolved_union.clone()); + self + } + + fn build(self) -> Result { + let info = self.resolved.ok_or_else(|| { + ArrowError::InvalidArgumentError( + "UnionResolutionBuilder requires resolved_union to be provided".to_string(), + ) + })?; + match (info.writer_is_union, info.reader_is_union) { + (true, true) => { + let fields = self.fields.ok_or_else(|| { + ArrowError::InvalidArgumentError( + "UnionResolutionBuilder for reader union requires fields".to_string(), + ) + })?; + let reader_type_codes: Vec = + fields.iter().map(|(tid, _)| tid).collect::>(); + let dispatch: Vec = info + .writer_to_reader + .iter() + .map(|m| match m { + Some((reader_index, promotion)) => BranchDispatch::ToReader { + reader_idx: *reader_index, + promotion: *promotion, + }, + None => BranchDispatch::NoMatch, + }) + .collect(); + Ok(UnionResolution { + dispatch: Some(Arc::from(dispatch)), + kind: UnionResolvedKind::Both { + reader_type_codes: Arc::from(reader_type_codes), + }, + }) + } + (false, true) => { + let fields = self.fields.ok_or_else(|| { + ArrowError::InvalidArgumentError( + "UnionResolutionBuilder for reader union requires fields".to_string(), + ) + })?; + let reader_type_codes: Vec = + fields.iter().map(|(tid, _)| tid).collect::>(); + let (target_reader_index, promotion) = + match info.writer_to_reader.first().and_then(|x| *x) { + Some(pair) => pair, + None => { + return Err(ArrowError::SchemaError( + "Writer schema does not match any reader union branch".to_string(), + )) + } + }; + Ok(UnionResolution { + dispatch: None, + kind: UnionResolvedKind::FromSingle { + 
reader_type_codes: Arc::from(reader_type_codes), + target_reader_index, + promotion, + }, + }) + } + (true, false) => { + let dispatch: Vec = info + .writer_to_reader + .iter() + .map(|m| match m { + Some((reader_index, promotion)) => BranchDispatch::ToReader { + reader_idx: *reader_index, + promotion: *promotion, + }, + None => BranchDispatch::NoMatch, + }) + .collect(); + Ok(UnionResolution { + dispatch: Some(Arc::from(dispatch)), + kind: UnionResolvedKind::ToSingle { + target: Box::new(Decoder::Null(0)), + }, + }) } + (false, false) => Err(ArrowError::InvalidArgumentError( + "UnionResolutionBuilder used for non-union case".to_string(), + )), } } - Ok(()) } #[derive(Debug)] @@ -252,7 +424,7 @@ enum Decoder { /// String data encoded as UTF-8 bytes, but mapped to Arrow's StringViewArray StringView(OffsetBufferBuilder, Vec), Array(FieldRef, OffsetBufferBuilder, Box), - Record(Fields, Vec), + Record(Fields, Vec, Option), Map( FieldRef, OffsetBufferBuilder, @@ -261,32 +433,59 @@ enum Decoder { Box, ), Fixed(i32, Vec), - Enum(Vec, Arc<[String]>), + Enum(Vec, Arc<[String]>, Option), Duration(IntervalMonthDayNanoBuilder), Uuid(Vec), + #[cfg(feature = "small_decimals")] Decimal32(usize, Option, Option, Decimal32Builder), + #[cfg(feature = "small_decimals")] Decimal64(usize, Option, Option, Decimal64Builder), Decimal128(usize, Option, Option, Decimal128Builder), Decimal256(usize, Option, Option, Decimal256Builder), + Union( + UnionFields, + Vec, + Vec, + Vec, + Vec, + Option, + ), Nullable(Nullability, NullBufferBuilder, Box), - EnumResolved { - indices: Vec, - symbols: Arc<[String]>, - mapping: Arc<[i32]>, - default_index: i32, - }, - /// Resolved record that needs writer->reader projection and skipping writer-only fields - RecordResolved { - fields: Fields, - encodings: Vec, - writer_to_reader: Arc<[Option]>, - skip_decoders: Vec>, - }, } impl Decoder { fn try_new(data_type: &AvroDataType) -> Result { - // Extract just the Promotion (if any) to simplify pattern matching 
+ if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { + if info.writer_is_union && !info.reader_is_union { + let mut clone = data_type.clone(); + clone.resolution = None; + let target = Box::new(Self::try_new_internal(&clone)?); + let mut union_resolution = UnionResolutionBuilder::new() + .with_resolved_union(info) + .build()?; + if let UnionResolvedKind::ToSingle { target: t } = &mut union_resolution.kind { + *t = target; + } + let base = Self::Union( + UnionFields::empty(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Some(union_resolution), + ); + return Ok(match data_type.nullability() { + Some(n) => { + Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)) + } + None => base, + }); + } + } + Self::try_new_internal(data_type) + } + + fn try_new_internal(data_type: &AvroDataType) -> Result { let promotion = match data_type.resolution.as_ref() { Some(ResolutionInfo::Promotion(p)) => Some(p), _ => None, @@ -403,16 +602,14 @@ impl Decoder { ) } (Codec::Enum(symbols), _) => { - if let Some(ResolutionInfo::EnumMapping(mapping)) = data_type.resolution.as_ref() { - Self::EnumResolved { - indices: Vec::with_capacity(DEFAULT_CAPACITY), - symbols: symbols.clone(), + let res = match data_type.resolution.as_ref() { + Some(ResolutionInfo::EnumMapping(mapping)) => Some(EnumResolution { mapping: mapping.mapping.clone(), default_index: mapping.default_index, - } - } else { - Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) - } + }), + _ => None, + }; + Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone(), res) } (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); @@ -422,17 +619,17 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - if let Some(ResolutionInfo::Record(rec)) = data_type.resolution.as_ref() { - let skip_decoders = build_skip_decoders(&rec.skip_fields)?; - Self::RecordResolved { - fields: arrow_fields.into(), 
- encodings, - writer_to_reader: rec.writer_to_reader.clone(), - skip_decoders, - } - } else { - Self::Record(arrow_fields.into(), encodings) - } + let projector = + if let Some(ResolutionInfo::Record(rec)) = data_type.resolution.as_ref() { + Some( + ProjectorBuilder::try_new(rec) + .with_reader_fields(fields) + .build()?, + ) + } else { + None + }; + Self::Record(arrow_fields.into(), encodings, projector) } (Codec::Map(child), _) => { let val_field = child.field_with_name("value"); @@ -453,6 +650,34 @@ impl Decoder { Box::new(val_dec), ) } + (Codec::Union(encodings, fields, mode), _) => { + if *mode != UnionMode::Dense { + return Err(ArrowError::NotYetImplemented( + "Sparse Arrow unions are not yet supported".to_string(), + )); + } + let mut decoders = Vec::with_capacity(encodings.len()); + for c in encodings.iter() { + decoders.push(Self::try_new_internal(c)?); + } + let union_resolution = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Union(info)) if info.reader_is_union => Some( + UnionResolutionBuilder::new() + .with_fields(fields.clone()) + .with_resolved_union(info) + .build()?, + ), + _ => None, + }; + Self::Union( + fields.clone(), + Vec::with_capacity(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + decoders, + vec![0; encodings.len()], + union_resolution, + ) + } (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), }; Ok(match data_type.nullability() { @@ -491,30 +716,389 @@ impl Decoder { Self::Uuid(v) => { v.extend([0; 16]); } - Self::Array(_, offsets, e) => { + Self::Array(_, offsets, _e) => { offsets.push_length(0); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), + Self::Record(_, e, _) => e.iter_mut().for_each(|e| e.append_null()), Self::Map(_, _koff, moff, _, _) => { moff.push_length(0); } Self::Fixed(sz, accum) => { accum.extend(std::iter::repeat_n(0u8, *sz as usize)); } + #[cfg(feature = "small_decimals")] Self::Decimal32(_, _, _, builder) => builder.append_value(0), + #[cfg(feature 
= "small_decimals")] Self::Decimal64(_, _, _, builder) => builder.append_value(0), Self::Decimal128(_, _, _, builder) => builder.append_value(0), Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), - Self::Enum(indices, _) => indices.push(0), - Self::EnumResolved { indices, .. } => indices.push(0), + Self::Enum(indices, _, _) => indices.push(0), Self::Duration(builder) => builder.append_null(), + Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { + let mut chosen = None; + for (i, ch) in encodings.iter().enumerate() { + if matches!(ch, Decoder::Null(_)) { + chosen = Some(i); + break; + } + } + let idx = chosen.unwrap_or(0); + let type_id = fields + .iter() + .nth(idx) + .map(|(type_id, _)| type_id) + .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); + type_ids.push(type_id); + offsets.push(encoding_counts[idx]); + encodings[idx].append_null(); + encoding_counts[idx] += 1; + } + Self::Union( + fields, + type_ids, + offsets, + encodings, + encoding_counts, + Some(union_resolution), + ) => match &mut union_resolution.kind { + UnionResolvedKind::Both { .. } => { + let mut chosen = None; + for (i, ch) in encodings.iter().enumerate() { + if matches!(ch, Decoder::Null(_)) { + chosen = Some(i); + break; + } + } + let idx = chosen.unwrap_or(0); + let type_id = fields + .iter() + .nth(idx) + .map(|(type_id, _)| type_id) + .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); + type_ids.push(type_id); + offsets.push(encoding_counts[idx]); + encodings[idx].append_null(); + encoding_counts[idx] += 1; + } + UnionResolvedKind::ToSingle { target } => { + target.append_null(); + } + UnionResolvedKind::FromSingle { + target_reader_index, + .. 
+ } => { + let type_id = fields + .iter() + .nth(*target_reader_index) + .map(|(type_id, _)| type_id) + .unwrap_or(0); + type_ids.push(type_id); + offsets.push(encoding_counts[*target_reader_index]); + encodings[*target_reader_index].append_null(); + encoding_counts[*target_reader_index] += 1; + } + }, Self::Nullable(_, null_buffer, inner) => { null_buffer.append(false); inner.append_null(); } - Self::RecordResolved { encodings, .. } => { - encodings.iter_mut().for_each(|e| e.append_null()); + } + } + + /// Append a single default literal into the decoder's buffers + fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> { + match self { + Self::Nullable(_, nb, inner) => { + if matches!(lit, AvroLiteral::Null) { + nb.append(false); + inner.append_null(); + Ok(()) + } else { + nb.append(true); + inner.append_default(lit) + } + } + Self::Null(count) => match lit { + AvroLiteral::Null => { + *count += 1; + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Non-null default for null type".to_string(), + )), + }, + Self::Boolean(b) => match lit { + AvroLiteral::Boolean(v) => { + b.append(*v); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for boolean must be boolean".to_string(), + )), + }, + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => match lit { + AvroLiteral::Int(i) => { + v.push(*i); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for int32/date32/time-millis must be int".to_string(), + )), + }, + Self::Int64(v) + | Self::Int32ToInt64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | Self::TimestampMicros(_, v) => match lit { + AvroLiteral::Long(i) => { + v.push(*i); + Ok(()) + } + AvroLiteral::Int(i) => { + v.push(*i as i64); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for long/time-micros/timestamp must be long or int".to_string(), + )), + }, + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => match lit { + 
AvroLiteral::Float(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for float must be float".to_string(), + )), + }, + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => match lit { + AvroLiteral::Double(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for double must be double".to_string(), + )), + }, + Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => match lit { + AvroLiteral::Bytes(b) => { + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for bytes must be bytes".to_string(), + )), + }, + Self::BytesToString(offsets, values) + | Self::String(offsets, values) + | Self::StringView(offsets, values) => match lit { + AvroLiteral::String(s) => { + let b = s.as_bytes(); + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for string must be string".to_string(), + )), + }, + Self::Uuid(values) => match lit { + AvroLiteral::String(s) => { + let uuid = Uuid::try_parse(s).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UUID default: {s} ({e})")) + })?; + values.extend_from_slice(uuid.as_bytes()); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for uuid must be string".to_string(), + )), + }, + Self::Fixed(sz, accum) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != *sz as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "Fixed default length {} does not match size {sz}", + b.len(), + ))); + } + accum.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for fixed must be bytes".to_string(), + )), + }, + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, _, builder) => { + append_decimal_default!(lit, builder, 4, i32, "decimal32") + } + 
#[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, _, builder) => { + append_decimal_default!(lit, builder, 8, i64, "decimal64") + } + Self::Decimal128(_, _, _, builder) => { + append_decimal_default!(lit, builder, 16, i128, "decimal128") + } + Self::Decimal256(_, _, _, builder) => { + append_decimal_default!(lit, builder, 32, i256, "decimal256") } + Self::Duration(builder) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != 12 { + return Err(ArrowError::InvalidArgumentError(format!( + "Duration default must be exactly 12 bytes, got {}", + b.len() + ))); + } + let months = u32::from_le_bytes([b[0], b[1], b[2], b[3]]); + let days = u32::from_le_bytes([b[4], b[5], b[6], b[7]]); + let millis = u32::from_le_bytes([b[8], b[9], b[10], b[11]]); + let nanos = (millis as i64) * 1_000_000; + builder.append_value(IntervalMonthDayNano::new( + months as i32, + days as i32, + nanos, + )); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for duration must be 12-byte little-endian months/days/millis" + .to_string(), + )), + }, + Self::Array(_, offsets, inner) => match lit { + AvroLiteral::Array(items) => { + offsets.push_length(items.len()); + for item in items { + inner.append_default(item)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for array must be an array literal".to_string(), + )), + }, + Self::Map(_, koff, moff, kdata, valdec) => match lit { + AvroLiteral::Map(entries) => { + moff.push_length(entries.len()); + for (k, v) in entries { + let kb = k.as_bytes(); + koff.push_length(kb.len()); + kdata.extend_from_slice(kb); + valdec.append_default(v)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for map must be a map/object literal".to_string(), + )), + }, + Self::Enum(indices, symbols, _) => match lit { + AvroLiteral::Enum(sym) => { + let pos = symbols.iter().position(|s| s == sym).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Enum default symbol {sym:?} not in 
reader symbols" + )) + })?; + indices.push(pos as i32); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for enum must be a symbol".to_string(), + )), + }, + Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { + if encodings.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Union default cannot be applied to empty union".to_string(), + )); + } + let type_id = fields + .iter() + .nth(0) + .map(|(type_id, _)| type_id) + .unwrap_or(0_i8); + type_ids.push(type_id); + offsets.push(encoding_counts[0]); + encodings[0].append_default(lit)?; + encoding_counts[0] += 1; + Ok(()) + } + Self::Union( + fields, + type_ids, + offsets, + encodings, + encoding_counts, + Some(union_resolution), + ) => match &mut union_resolution.kind { + UnionResolvedKind::Both { .. } => { + if encodings.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Union default cannot be applied to empty union".to_string(), + )); + } + let type_id = fields + .iter() + .nth(0) + .map(|(type_id, _)| type_id) + .unwrap_or(0_i8); + type_ids.push(type_id); + offsets.push(encoding_counts[0]); + encodings[0].append_default(lit)?; + encoding_counts[0] += 1; + Ok(()) + } + UnionResolvedKind::ToSingle { target } => target.append_default(lit), + UnionResolvedKind::FromSingle { + target_reader_index, + .. 
+ } => { + let type_id = fields + .iter() + .nth(*target_reader_index) + .map(|(type_id, _)| type_id) + .unwrap_or(0_i8); + type_ids.push(type_id); + offsets.push(encoding_counts[*target_reader_index]); + encodings[*target_reader_index].append_default(lit)?; + encoding_counts[*target_reader_index] += 1; + Ok(()) + } + }, + Self::Record(field_meta, decoders, projector) => match lit { + AvroLiteral::Map(entries) => { + for (i, dec) in decoders.iter_mut().enumerate() { + let name = field_meta[i].name(); + if let Some(sub) = entries.get(name) { + dec.append_default(sub)?; + } else if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null(); + } + } + Ok(()) + } + AvroLiteral::Null => { + for (i, dec) in decoders.iter_mut().enumerate() { + if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null(); + } + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for record must be a map/object or null".to_string(), + )), + }, } } @@ -560,11 +1144,14 @@ impl Decoder { let total_items = read_blocks(buf, |cursor| encoding.decode(cursor))?; off.push_length(total_items); } - Self::Record(_, encodings) => { + Self::Record(_, encodings, None) => { for encoding in encodings { encoding.decode(buf)?; } } + Self::Record(_, encodings, Some(proj)) => { + proj.project_record(buf, encodings)?; + } Self::Map(_, koff, moff, kdata, valdec) => { let newly_added = read_blocks(buf, |cur| { let kb = cur.get_bytes()?; @@ -578,9 +1165,11 @@ impl Decoder { let fx = buf.get_fixed(*sz as usize)?; accum.extend_from_slice(fx); } + #[cfg(feature = "small_decimals")] Self::Decimal32(_, _, size, builder) => { decode_decimal!(size, buf, builder, 4, i32); } + #[cfg(feature = "small_decimals")] Self::Decimal64(_, _, size, builder) => { decode_decimal!(size, buf, builder, 8, i64); } @@ -590,21 +1179,16 @@ impl Decoder { Self::Decimal256(_, _, size, builder) => { decode_decimal!(size, buf, builder, 32, 
i256); } - Self::Enum(indices, _) => { + Self::Enum(indices, _, None) => { indices.push(buf.get_int()?); } - Self::EnumResolved { - indices, - mapping, - default_index, - .. - } => { + Self::Enum(indices, _, Some(res)) => { let raw = buf.get_int()?; let resolved = usize::try_from(raw) .ok() - .and_then(|idx| mapping.get(idx).copied()) + .and_then(|idx| res.mapping.get(idx).copied()) .filter(|&idx| idx >= 0) - .unwrap_or(*default_index); + .unwrap_or(res.default_index); if resolved >= 0 { indices.push(resolved); } else { @@ -621,8 +1205,88 @@ impl Decoder { let nanos = (millis as i64) * 1_000_000; builder.append_value(IntervalMonthDayNano::new(months as i32, days as i32, nanos)); } + Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { + let branch = buf.get_long()?; + if branch < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {branch}" + ))); + } + let idx = branch as usize; + if idx >= encodings.len() { + return Err(ArrowError::ParseError(format!( + "Union branch index {idx} out of range ({} branches)", + encodings.len() + ))); + } + let type_id = fields + .iter() + .nth(idx) + .map(|(type_id, _)| type_id) + .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); + type_ids.push(type_id); + offsets.push(encoding_counts[idx]); + encodings[idx].decode(buf)?; + encoding_counts[idx] += 1; + } + Self::Union( + _, + type_ids, + offsets, + encodings, + encoding_counts, + Some(union_resolution), + ) => match &mut union_resolution.kind { + UnionResolvedKind::Both { + reader_type_codes, .. 
+ } => { + let (idx, action) = get_writer_union_action!(buf, union_resolution); + match action { + BranchDispatch::NoMatch => { + return Err(ArrowError::ParseError(format!( + "Union branch index {idx} not resolvable by reader schema" + ))); + } + BranchDispatch::ToReader { + reader_idx, + promotion, + } => { + let type_id = reader_type_codes[reader_idx]; + type_ids.push(type_id); + offsets.push(encoding_counts[reader_idx]); + encodings[reader_idx].decode_with_promotion(buf, promotion)?; + encoding_counts[reader_idx] += 1; + } + } + } + UnionResolvedKind::ToSingle { target } => { + let (idx, action) = get_writer_union_action!(buf, union_resolution); + match action { + BranchDispatch::NoMatch => { + return Err(ArrowError::ParseError(format!( + "Writer union branch {idx} does not resolve to reader type" + ))); + } + BranchDispatch::ToReader { promotion, .. } => { + target.decode_with_promotion(buf, promotion)?; + } + } + } + UnionResolvedKind::FromSingle { + reader_type_codes, + target_reader_index, + promotion, + .. + } => { + let type_id = reader_type_codes[*target_reader_index]; + type_ids.push(type_id); + offsets.push(encoding_counts[*target_reader_index]); + encodings[*target_reader_index].decode_with_promotion(buf, *promotion)?; + encoding_counts[*target_reader_index] += 1; + } + }, Self::Nullable(order, nb, encoding) => { - let branch = buf.read_vlq()?; + let branch = buf.get_long()?; let is_not_null = match *order { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, @@ -635,18 +1299,98 @@ impl Decoder { } nb.append(is_not_null); } - Self::RecordResolved { - encodings, - writer_to_reader, - skip_decoders, - .. 
- } => { - decode_with_resolution(buf, encodings, writer_to_reader, skip_decoders)?; - } } Ok(()) } + fn decode_with_promotion( + &mut self, + buf: &mut AvroCursor<'_>, + promotion: Promotion, + ) -> Result<(), ArrowError> { + match promotion { + Promotion::Direct => self.decode(buf), + Promotion::IntToLong => match self { + Self::Int64(v) => { + v.push(buf.get_int()? as i64); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Int->Long target mismatch".into(), + )), + }, + Promotion::IntToFloat => match self { + Self::Float32(v) => { + v.push(buf.get_int()? as f32); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Int->Float target mismatch".into(), + )), + }, + Promotion::IntToDouble => match self { + Self::Float64(v) => { + v.push(buf.get_int()? as f64); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Int->Double target mismatch".into(), + )), + }, + Promotion::LongToFloat => match self { + Self::Float32(v) => { + v.push(buf.get_long()? as f32); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Long->Float target mismatch".into(), + )), + }, + Promotion::LongToDouble => match self { + Self::Float64(v) => { + v.push(buf.get_long()? as f64); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Long->Double target mismatch".into(), + )), + }, + Promotion::FloatToDouble => match self { + Self::Float64(v) => { + v.push(buf.get_float()? 
as f64); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Float->Double target mismatch".into(), + )), + }, + Promotion::StringToBytes => match self { + Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion String->Bytes target mismatch".into(), + )), + }, + Promotion::BytesToString => match self { + Self::String(offsets, values) + | Self::StringView(offsets, values) + | Self::BytesToString(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + Ok(()) + } + _ => Err(ArrowError::ParseError( + "Promotion Bytes->String target mismatch".into(), + )), + }, + } + } + /// Flush decoded records to an [`ArrayRef`] fn flush(&mut self, nulls: Option) -> Result { Ok(match self { @@ -711,7 +1455,7 @@ impl Decoder { let offsets = flush_offsets(offsets); Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) } - Self::Record(fields, encodings) => { + Self::Record(fields, encodings, _) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) @@ -764,9 +1508,11 @@ impl Decoder { .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(arr) } + #[cfg(feature = "small_decimals")] Self::Decimal32(precision, scale, _, builder) => { flush_decimal!(builder, precision, scale, nulls, Decimal32Array) } + #[cfg(feature = "small_decimals")] Self::Decimal64(precision, scale, _, builder) => { flush_decimal!(builder, precision, scale, nulls, Decimal64Array) } @@ -776,24 +1522,23 @@ impl Decoder { Self::Decimal256(precision, scale, _, builder) => { flush_decimal!(builder, precision, scale, nulls, Decimal256Array) } - Self::Enum(indices, symbols) => flush_dict(indices, symbols, nulls)?, - Self::EnumResolved { - indices, symbols, .. 
- } => flush_dict(indices, symbols, nulls)?, + Self::Enum(indices, symbols, _) => flush_dict(indices, symbols, nulls)?, Self::Duration(builder) => { let (_, vals, _) = builder.finish().into_parts(); let vals = IntervalMonthDayNanoArray::try_new(vals, nulls) .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(vals) } - Self::RecordResolved { - fields, encodings, .. - } => { - let arrays = encodings - .iter_mut() - .map(|x| x.flush(None)) - .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + Self::Union(fields, type_ids, offsets, encodings, _, None) => { + flush_union!(fields, type_ids, offsets, encodings) + } + Self::Union(fields, type_ids, offsets, encodings, _, Some(union_resolution)) => { + match &mut union_resolution.kind { + UnionResolvedKind::Both { .. } | UnionResolvedKind::FromSingle { .. } => { + flush_union!(fields, type_ids, offsets, encodings) + } + UnionResolvedKind::ToSingle { target } => target.flush(nulls)?, + } } }) } @@ -976,6 +1721,121 @@ fn sign_cast_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { Ok(out) } +#[derive(Debug)] +struct Projector { + writer_to_reader: Arc<[Option]>, + skip_decoders: Vec>, + field_defaults: Vec>, + default_injections: Arc<[(usize, AvroLiteral)]>, +} + +#[derive(Debug)] +struct ProjectorBuilder<'a> { + rec: &'a ResolvedRecord, + reader_fields: Option>, +} + +impl<'a> ProjectorBuilder<'a> { + #[inline] + fn try_new(rec: &'a ResolvedRecord) -> Self { + Self { + rec, + reader_fields: None, + } + } + + #[inline] + fn with_reader_fields(mut self, reader_fields: &Arc<[AvroField]>) -> Self { + self.reader_fields = Some(reader_fields.clone()); + self + } + + #[inline] + fn build(self) -> Result { + let reader_fields = self.reader_fields.ok_or_else(|| { + ArrowError::InvalidArgumentError( + "ProjectorBuilder requires reader_fields to be provided".to_string(), + ) + })?; + let mut field_defaults: Vec> = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.iter() { + 
if let Some(ResolutionInfo::DefaultValue(lit)) = + avro_field.data_type().resolution.as_ref() + { + field_defaults.push(Some(lit.clone())); + } else { + field_defaults.push(None); + } + } + let mut default_injections: Vec<(usize, AvroLiteral)> = + Vec::with_capacity(self.rec.default_fields.len()); + for &idx in self.rec.default_fields.iter() { + let lit = field_defaults + .get(idx) + .and_then(|lit| lit.clone()) + .unwrap_or(AvroLiteral::Null); + default_injections.push((idx, lit)); + } + let mut skip_decoders: Vec> = + Vec::with_capacity(self.rec.skip_fields.len()); + for datatype in self.rec.skip_fields.iter() { + let skipper = match datatype { + Some(datatype) => Some(Skipper::from_avro(datatype)?), + None => None, + }; + skip_decoders.push(skipper); + } + Ok(Projector { + writer_to_reader: self.rec.writer_to_reader.clone(), + skip_decoders, + field_defaults, + default_injections: default_injections.into(), + }) + } +} + +impl Projector { + #[inline] + fn project_default(&self, decoder: &mut Decoder, index: usize) -> Result<(), ArrowError> { + if let Some(default_literal) = self.field_defaults[index].as_ref() { + decoder.append_default(default_literal) + } else { + decoder.append_null(); + Ok(()) + } + } + + #[inline] + fn project_record( + &mut self, + buf: &mut AvroCursor<'_>, + encodings: &mut [Decoder], + ) -> Result<(), ArrowError> { + let n_writer = self.writer_to_reader.len(); + let n_injections = self.default_injections.len(); + for index in 0..(n_writer + n_injections) { + if index < n_writer { + match ( + self.writer_to_reader[index], + self.skip_decoders[index].as_mut(), + ) { + (Some(reader_index), _) => encodings[reader_index].decode(buf)?, + (None, Some(skipper)) => skipper.skip(buf)?, + (None, None) => { + return Err(ArrowError::SchemaError(format!( + "No skipper available for writer-only field at index {index}", + ))); + } + } + } else { + let (reader_index, ref lit) = self.default_injections[index - n_writer]; + 
encodings[reader_index].append_default(lit)?; + } + } + Ok(()) + } +} + /// Lightweight skipper for non‑projected writer fields /// (fields present in the writer schema but omitted by the reader/projection); /// per Avro 1.11.1 schema resolution these fields are ignored. @@ -1004,6 +1864,7 @@ enum Skipper { List(Box), Map(Box), Struct(Vec), + Union(Vec), Nullable(Nullability, Box), } @@ -1034,6 +1895,12 @@ impl Skipper { ), Codec::Map(values) => Self::Map(Box::new(Skipper::from_avro(values)?)), Codec::Interval => Self::DurationFixed12, + Codec::Union(encodings, _, _) => Self::Union( + encodings + .iter() + .map(Skipper::from_avro) + .collect::>()?, + ), _ => { return Err(ArrowError::NotYetImplemented(format!( "Skipper not implemented for codec {:?}", @@ -1111,8 +1978,26 @@ impl Skipper { } Ok(()) } + Self::Union(encodings) => { + // Union tag must be ZigZag-decoded + let branch = buf.get_long()?; + if branch < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {branch}" + ))); + } + let idx = branch as usize; + if let Some(encoding) = encodings.get_mut(idx) { + encoding.skip(buf) + } else { + Err(ArrowError::ParseError(format!( + "Union branch index {idx} out of range for skipper ({} branches)", + encodings.len() + ))) + } + } Self::Nullable(order, inner) => { - let branch = buf.read_vlq()?; + let branch = buf.get_long()?; let is_not_null = match *order { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, @@ -1126,20 +2011,11 @@ impl Skipper { } } -#[inline] -fn build_skip_decoders( - skip_fields: &[Option], -) -> Result>, ArrowError> { - skip_fields - .iter() - .map(|opt| opt.as_ref().map(Skipper::from_avro).transpose()) - .collect() -} - #[cfg(test)] mod tests { use super::*; use crate::codec::AvroField; + use crate::schema::{PrimitiveType, Schema, TypeName}; use arrow_array::{ cast::AsArray, Array, Decimal128Array, Decimal256Array, Decimal32Array, DictionaryArray, FixedSizeBinaryArray, 
IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, @@ -1190,6 +2066,142 @@ mod tests { Decoder::try_new(field.data_type()).unwrap() } + #[test] + fn test_union_resolution_writer_union_reader_union_reorder_and_promotion_dense() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + let mut rec1 = encode_avro_long(0); + rec1.extend(encode_avro_int(7)); + let mut cur1 = AvroCursor::new(&rec1); + dec.decode(&mut cur1).unwrap(); + let mut rec2 = encode_avro_long(1); + rec2.extend(encode_avro_bytes("abc".as_bytes())); + let mut cur2 = AvroCursor::new(&rec2); + dec.decode(&mut cur2).unwrap(); + let arr = dec.flush(None).unwrap(); + let ua = arr + .as_any() + .downcast_ref::() + .expect("dense union output"); + assert_eq!( + ua.type_id(0), + 1, + "first value must select reader 'long' branch" + ); + assert_eq!(ua.value_offset(0), 0); + assert_eq!( + ua.type_id(1), + 0, + "second value must select reader 'string' branch" + ); + assert_eq!(ua.value_offset(1), 0); + let long_child = ua.child(1).as_any().downcast_ref::().unwrap(); + assert_eq!(long_child.len(), 1); + assert_eq!(long_child.value(0), 7); + let str_child = ua.child(0).as_any().downcast_ref::().unwrap(); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "abc"); + } + + #[test] + fn test_union_resolution_writer_union_reader_nonunion_promotion_int_to_long() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = 
Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + let mut data = encode_avro_long(0); + data.extend(encode_avro_int(5)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let out = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(out.len(), 1); + assert_eq!(out.value(0), 5); + } + + #[test] + fn test_union_resolution_writer_union_reader_nonunion_mismatch_errors() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + let mut data = encode_avro_long(1); + data.extend(encode_avro_bytes("z".as_bytes())); + let mut cur = AvroCursor::new(&data); + let res = dec.decode(&mut cur); + assert!( + res.is_err(), + "expected error when writer union branch does not resolve to reader non-union type" + ); + } + + #[test] + fn test_union_resolution_writer_nonunion_reader_union_selects_matching_branch() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + let data = encode_avro_int(6); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let ua = arr + .as_any() + .downcast_ref::() + .expect("dense union output"); + assert_eq!(ua.len(), 1); 
+ assert_eq!( + ua.type_id(0), + 1, + "must resolve to reader 'long' branch (type_id 1)" + ); + assert_eq!(ua.value_offset(0), 0); + let long_child = ua.child(1).as_any().downcast_ref::().unwrap(); + assert_eq!(long_child.len(), 1); + assert_eq!(long_child.value(0), 6); + let str_child = ua.child(0).as_any().downcast_ref::().unwrap(); + assert_eq!(str_child.len(), 0, "string branch must be empty"); + } + + #[test] + fn test_union_resolution_writer_union_reader_union_unmapped_branch_errors() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), + ]); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + let mut data = encode_avro_long(1); + data.push(1); + let mut cur = AvroCursor::new(&data); + let res = dec.decode(&mut cur); + assert!( + res.is_err(), + "expected error for unmapped writer 'boolean' branch" + ); + } + #[test] fn test_schema_resolution_promotion_int_to_long() { let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Long, false); @@ -1977,12 +2989,14 @@ mod tests { vec!["B".to_string(), "C".to_string(), "A".to_string()].into(); let mapping: Arc<[i32]> = Arc::from(vec![2, 0, 1]); let default_index: i32 = -1; - let mut dec = Decoder::EnumResolved { - indices: Vec::with_capacity(DEFAULT_CAPACITY), - symbols: reader_symbols.clone(), - mapping, - default_index, - }; + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); let mut data = Vec::new(); data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&encode_avro_int(1)); @@ -2013,12 +3027,14 @@ mod tests { let 
reader_symbols: Arc<[String]> = vec!["A".to_string(), "B".to_string()].into(); let default_index: i32 = 1; let mapping: Arc<[i32]> = Arc::from(vec![0, 1]); - let mut dec = Decoder::EnumResolved { - indices: Vec::with_capacity(DEFAULT_CAPACITY), - symbols: reader_symbols.clone(), - mapping, - default_index, - }; + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); let mut data = Vec::new(); data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&encode_avro_int(1)); @@ -2048,12 +3064,14 @@ mod tests { let reader_symbols: Arc<[String]> = vec!["A".to_string()].into(); let default_index: i32 = -1; // indicates no default at type-level let mapping: Arc<[i32]> = Arc::from(vec![-1]); - let mut dec = Decoder::EnumResolved { - indices: Vec::with_capacity(DEFAULT_CAPACITY), - symbols: reader_symbols, - mapping, - default_index, - }; + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols, + Some(EnumResolution { + mapping, + default_index, + }), + ); let data = encode_avro_int(0); let mut cur = AvroCursor::new(&data); let err = dec @@ -2069,7 +3087,7 @@ mod tests { fn make_record_resolved_decoder( reader_fields: &[(&str, DataType, bool)], writer_to_reader: Vec>, - mut skip_decoders: Vec>, + skip_decoders: Vec>, ) -> Decoder { let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); let mut encodings: Vec = Vec::with_capacity(reader_fields.len()); @@ -2086,12 +3104,16 @@ mod tests { encodings.push(enc); } let fields: Fields = field_refs.into(); - Decoder::RecordResolved { + Decoder::Record( fields, encodings, - writer_to_reader: Arc::from(writer_to_reader), - skip_decoders, - } + Some(Projector { + writer_to_reader: Arc::from(writer_to_reader), + skip_decoders, + field_defaults: vec![None; reader_fields.len()], + default_injections: Arc::from(Vec::<(usize, AvroLiteral)>::new()), + }), + ) } #[test] @@ -2257,4 
+3279,181 @@ mod tests { assert_eq!(id.value(0), 5); assert_eq!(id.value(1), 7); } + + fn make_dense_union_avro( + children: Vec<(Codec, &'static str, DataType)>, + type_ids: Vec, + ) -> AvroDataType { + let mut avro_children: Vec = Vec::with_capacity(children.len()); + let mut fields: Vec = Vec::with_capacity(children.len()); + + for (codec, name, dt) in children.into_iter() { + avro_children.push(AvroDataType::new(codec, Default::default(), None)); + fields.push(arrow_schema::Field::new(name, dt, true)); + } + let union_fields = UnionFields::new(type_ids, fields); + let union_codec = Codec::Union(avro_children.into(), union_fields, UnionMode::Dense); + AvroDataType::new(union_codec, Default::default(), None) + } + + #[test] + fn test_union_dense_two_children_custom_type_ids() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![2, 5], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let mut r1 = Vec::new(); + r1.extend_from_slice(&encode_avro_long(0)); + r1.extend_from_slice(&encode_avro_int(7)); + let mut r2 = Vec::new(); + r2.extend_from_slice(&encode_avro_long(1)); + r2.extend_from_slice(&encode_avro_bytes(b"x")); + let mut r3 = Vec::new(); + r3.extend_from_slice(&encode_avro_long(0)); + r3.extend_from_slice(&encode_avro_int(-1)); + dec.decode(&mut AvroCursor::new(&r1)).unwrap(); + dec.decode(&mut AvroCursor::new(&r2)).unwrap(); + dec.decode(&mut AvroCursor::new(&r3)).unwrap(); + let array = dec.flush(None).unwrap(); + let ua = array + .as_any() + .downcast_ref::() + .expect("expected UnionArray"); + assert_eq!(ua.len(), 3); + assert_eq!(ua.type_id(0), 2); + assert_eq!(ua.type_id(1), 5); + assert_eq!(ua.type_id(2), 2); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + assert_eq!(ua.value_offset(2), 1); + let int_child = ua + .child(2) + .as_any() + .downcast_ref::() + .expect("int child"); + assert_eq!(int_child.len(), 2); + 
assert_eq!(int_child.value(0), 7); + assert_eq!(int_child.value(1), -1); + let str_child = ua + .child(5) + .as_any() + .downcast_ref::() + .expect("string child"); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "x"); + } + + #[test] + fn test_union_dense_with_null_and_string_children() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Null, "n", DataType::Null), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![42, 7], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let r1 = encode_avro_long(0); + let mut r2 = Vec::new(); + r2.extend_from_slice(&encode_avro_long(1)); + r2.extend_from_slice(&encode_avro_bytes(b"abc")); + let r3 = encode_avro_long(0); + dec.decode(&mut AvroCursor::new(&r1)).unwrap(); + dec.decode(&mut AvroCursor::new(&r2)).unwrap(); + dec.decode(&mut AvroCursor::new(&r3)).unwrap(); + let array = dec.flush(None).unwrap(); + let ua = array + .as_any() + .downcast_ref::() + .expect("expected UnionArray"); + assert_eq!(ua.len(), 3); + assert_eq!(ua.type_id(0), 42); + assert_eq!(ua.type_id(1), 7); + assert_eq!(ua.type_id(2), 42); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + assert_eq!(ua.value_offset(2), 1); + let null_child = ua + .child(42) + .as_any() + .downcast_ref::() + .expect("null child"); + assert_eq!(null_child.len(), 2); + let str_child = ua + .child(7) + .as_any() + .downcast_ref::() + .expect("string child"); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "abc"); + } + + #[test] + fn test_union_decode_negative_branch_index_errors() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![0, 1], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let row = encode_avro_long(-1); // decodes back to -1 + let err = dec + .decode(&mut AvroCursor::new(&row)) + .expect_err("expected error for negative branch index"); + let msg = err.to_string(); + assert!( + 
msg.contains("Negative union branch index"), + "unexpected error message: {msg}" + ); + } + + #[test] + fn test_union_decode_out_of_range_branch_index_errors() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![10, 11], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let row = encode_avro_long(2); + let err = dec + .decode(&mut AvroCursor::new(&row)) + .expect_err("expected error for out-of-range branch index"); + let msg = err.to_string(); + assert!( + msg.contains("out of range"), + "unexpected error message: {msg}" + ); + } + + #[test] + fn test_union_sparse_mode_not_supported() { + let children: Vec = vec![ + AvroDataType::new(Codec::Int32, Default::default(), None), + AvroDataType::new(Codec::Utf8, Default::default(), None), + ]; + let uf = UnionFields::new( + vec![1, 3], + vec![ + arrow_schema::Field::new("i", DataType::Int32, true), + arrow_schema::Field::new("s", DataType::Utf8, true), + ], + ); + let codec = Codec::Union(children.into(), uf, UnionMode::Sparse); + let dt = AvroDataType::new(codec, Default::default(), None); + let err = Decoder::try_new(&dt).expect_err("sparse union should not be supported"); + let msg = err.to_string(); + assert!( + msg.contains("Sparse Arrow unions are not yet supported"), + "unexpected error message: {msg}" + ); + } } diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 511ba280f7ae..6c501a56abe6 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -17,6 +17,7 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, IntervalUnit, Schema as ArrowSchema, TimeUnit, + UnionMode, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Map as JsonMap, Value}; @@ -94,7 +95,7 @@ pub enum TypeName<'a> { /// A primitive type /// /// -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, AsRefStr)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, 
Deserialize, AsRefStr)] #[serde(rename_all = "camelCase")] #[strum(serialize_all = "lowercase")] pub enum PrimitiveType { @@ -718,7 +719,7 @@ fn quote(s: &str) -> Result { // handling both ways of specifying the name. It prioritizes a namespace // defined within the `name` attribute itself, then the explicit `namespace_attr`, // and finally the `enclosing_ns`. -fn make_full_name( +pub(crate) fn make_full_name( name: &str, namespace_attr: Option<&str>, enclosing_ns: Option<&str>, @@ -955,6 +956,8 @@ fn merge_extras(schema: Value, mut extras: JsonMap) -> Value { Value::Object(map) } Value::Array(mut union) => { + // For unions, we cannot attach attributes to the array itself (per Avro spec). + // As a fallback for extension metadata, attach extras to the first non-null branch object. if let Some(non_null) = union.iter_mut().find(|val| val.as_str() != Some("null")) { let original = std::mem::take(non_null); *non_null = merge_extras(original, extras); @@ -970,13 +973,59 @@ fn merge_extras(schema: Value, mut extras: JsonMap) -> Value { } } +#[inline] +fn is_avro_json_null(v: &Value) -> bool { + matches!(v, Value::String(s) if s == "null") +} + fn wrap_nullable(inner: Value, null_order: Nullability) -> Value { let null = Value::String("null".into()); - let elements = match null_order { - Nullability::NullFirst => vec![null, inner], - Nullability::NullSecond => vec![inner, null], - }; - Value::Array(elements) + match inner { + Value::Array(mut union) => { + union.retain(|v| !is_avro_json_null(v)); + match null_order { + Nullability::NullFirst => { + let mut out = Vec::with_capacity(union.len() + 1); + out.push(null); + out.extend(union); + Value::Array(out) + } + Nullability::NullSecond => { + union.push(null); + Value::Array(union) + } + } + } + other => match null_order { + Nullability::NullFirst => Value::Array(vec![null, other]), + Nullability::NullSecond => Value::Array(vec![other, null]), + }, + } +} + +fn union_branch_signature(branch: &Value) -> Result { + match 
branch { + Value::String(t) => Ok(format!("P:{t}")), + Value::Object(map) => { + let t = map.get("type").and_then(|v| v.as_str()).ok_or_else(|| { + ArrowError::SchemaError("Union branch object missing string 'type'".into()) + })?; + match t { + "record" | "enum" | "fixed" => { + let name = map.get("name").and_then(|v| v.as_str()).unwrap_or_default(); + Ok(format!("N:{t}:{name}")) + } + "array" | "map" => Ok(format!("C:{t}")), + other => Ok(format!("P:{other}")), + } + } + Value::Array(_) => Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )), + _ => Err(ArrowError::SchemaError( + "Invalid JSON for Avro union branch".into(), + )), + } } fn datatype_to_avro( @@ -1028,6 +1077,10 @@ fn datatype_to_avro( DataType::Float64 => Value::String("double".into()), DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Value::String("string".into()), DataType::Binary | DataType::LargeBinary => Value::String("bytes".into()), + DataType::BinaryView => { + extras.insert("arrowBinaryView".into(), Value::Bool(true)); + Value::String("bytes".into()) + } DataType::FixedSizeBinary(len) => { let is_uuid = metadata .get("logicalType") @@ -1129,6 +1182,24 @@ fn datatype_to_avro( "items": items_schema }) } + DataType::ListView(child) | DataType::LargeListView(child) => { + if matches!(dt, DataType::LargeListView(_)) { + extras.insert("arrowLargeList".into(), Value::Bool(true)); + } + extras.insert("arrowListView".into(), Value::Bool(true)); + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; + json!({ + "type": "array", + "items": items_schema + }) + } DataType::FixedSizeList(child, len) => { extras.insert("arrowFixedSize".into(), json!(len)); let items_schema = process_datatype( @@ -1205,10 +1276,52 @@ fn datatype_to_avro( null_order, false, )?, - DataType::Union(_, _) => { - return Err(ArrowError::NotYetImplemented( - "Arrow Union to Avro 
Union not yet supported".into(), - )) + DataType::Union(fields, mode) => { + let mut branches: Vec = Vec::with_capacity(fields.len()); + let mut type_ids: Vec = Vec::with_capacity(fields.len()); + for (type_id, field_ref) in fields.iter() { + // NOTE: `process_datatype` would wrap nullability; force is_nullable=false here. + let (branch_schema, _branch_extras) = datatype_to_avro( + field_ref.data_type(), + field_ref.name(), + field_ref.metadata(), + name_gen, + null_order, + )?; + // Avro unions cannot immediately contain another union + if matches!(branch_schema, Value::Array(_)) { + return Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )); + } + branches.push(branch_schema); + type_ids.push(type_id as i32); + } + let mut seen: HashSet = HashSet::with_capacity(branches.len()); + for b in &branches { + let sig = union_branch_signature(b)?; + if !seen.insert(sig) { + return Err(ArrowError::SchemaError( + "Avro union contains duplicate branch types (disallowed by spec)".into(), + )); + } + } + extras.insert( + "arrowUnionMode".into(), + Value::String( + match mode { + UnionMode::Sparse => "sparse", + UnionMode::Dense => "dense", + } + .to_string(), + ), + ); + extras.insert( + "arrowUnionTypeIds".into(), + Value::Array(type_ids.into_iter().map(|id| json!(id)).collect()), + ); + + Value::Array(branches) } other => { return Err(ArrowError::NotYetImplemented(format!( @@ -1281,7 +1394,7 @@ fn arrow_field_to_avro( mod tests { use super::*; use crate::codec::{AvroDataType, AvroField}; - use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit}; + use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit, UnionFields}; use serde_json::json; use std::sync::Arc; @@ -1988,17 +2101,47 @@ mod tests { } #[test] - fn test_dense_union_error() { - use arrow_schema::UnionFields; - let uf: UnionFields = vec![(0i8, Arc::new(ArrowField::new("a", DataType::Int32, false)))] - .into_iter() - .collect(); - let union_dt = 
DataType::Union(uf, arrow_schema::UnionMode::Dense); + fn test_dense_union() { + let uf: UnionFields = vec![ + (2i8, Arc::new(ArrowField::new("a", DataType::Int32, false))), + (7i8, Arc::new(ArrowField::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + let union_dt = DataType::Union(uf, UnionMode::Dense); let s = single_field_schema(ArrowField::new("u", union_dt, false)); - let err = AvroSchema::try_from(&s).unwrap_err(); - assert!(err - .to_string() - .contains("Arrow Union to Avro Union not yet supported")); + let avro = + AvroSchema::try_from(&s).expect("Arrow Union -> Avro union conversion should succeed"); + let v: serde_json::Value = serde_json::from_str(&avro.json_string).unwrap(); + let fields = v + .get("fields") + .and_then(|x| x.as_array()) + .expect("fields array"); + let u_field = fields + .iter() + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("u")) + .expect("field 'u'"); + let union = u_field.get("type").expect("u.type"); + let arr = union.as_array().expect("u.type must be Avro union array"); + assert_eq!(arr.len(), 2, "expected two union branches"); + let first = &arr[0]; + let obj = first + .as_object() + .expect("first branch should be an object with metadata"); + assert_eq!(obj.get("type").and_then(|t| t.as_str()), Some("int")); + assert_eq!( + obj.get("arrowUnionMode").and_then(|m| m.as_str()), + Some("dense") + ); + let type_ids: Vec = obj + .get("arrowUnionTypeIds") + .and_then(|a| a.as_array()) + .expect("arrowUnionTypeIds array") + .iter() + .map(|n| n.as_i64().expect("i64")) + .collect(); + assert_eq!(type_ids, vec![2, 7], "type id ordering should be preserved"); + assert_eq!(arr[1], Value::String("string".into())); } #[test] diff --git a/arrow-avro/test/data/README.md b/arrow-avro/test/data/README.md index 51416c8416d4..8ddaaff3324a 100644 --- a/arrow-avro/test/data/README.md +++ b/arrow-avro/test/data/README.md @@ -20,12 +20,12 @@ # Avro test files for `arrow-avro` This directory contains small Avro Object 
Container Files (OCF) used by -`arrow-avro` tests to validate the `Reader` implementation. These files are generated from +`arrow-avro` tests to validate the `Reader` implementation. These files are generated from a set of python scripts and will gradually be removed as they are merged into `arrow-testing`. ## Decimal Files -This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type +This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type across both `bytes` and `fixed` encodings, and to cover Arrow decimal widths ranging from `Decimal32` up through `Decimal256`. The files were generated from a script (see **How these files were created** below). @@ -141,7 +141,62 @@ Options: * --scale (default 10) — the decimal scale used for the 256 files * --no-verify — skip reading the files back for printed verification +## Union File + +**Purpose:** Exercise a wide variety of Avro **union** shapes (including nullable unions, unions of ambiguous scalar types, unions of named types, and unions inside arrays, maps, and nested records) to validate `arrow-avro` union decoding and schema‑resolution paths. + +**Format:** Avro Object Container File (OCF) written by `fastavro.writer` with embedded writer schema. + +**Record count:** four rows. Each row selects different branches across the unions to ensure coverage (i.e., toggling between bytes vs. string, fixed vs. duration vs. decimal, enum vs. record alternatives, etc.). + +**How this file was created:** + +1. Script: [`create_avro_union_file.py`](https://gist.github.com/jecsand838/f4bf85ad597ab34575219df515156444) + Runs with Python 3 and uses **fastavro** to emit `union_fields.avro` in the working directory. +2. 
Quick reproduce: + ```bash + pip install fastavro + python3 create_avro_union_file.py + # Outputs: ./union_fields.avro + ``` + +> Note: Avro OCF files include a *sync marker*; `fastavro.writer` generates a random one if not provided, so byte‑for‑byte output may vary between runs even with the same data. This does not affect the embedded schema or logical content. + +**Writer schema (overview):** The record is named `UnionTypesRecord` and defines the following fields: + +| Field | Union branches / details | +|-----------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `nullable_int_nullfirst` | `["null","int"]` (tests null‑first ordering) | +| `nullable_string_nullsecond` | `["string","null"]` (tests null‑second ordering; in Avro, a union field’s default must match the *first* branch) | +| `union_prim` | `["boolean","int","long","float","double","bytes","string"]` | +| `union_bytes_vs_string` | `["bytes","string"]` (ambiguous scalar union; script uses fastavro’s tuple notation to disambiguate) | +| `union_fixed_dur_decfix` | `["Fx8","Dur12","DecFix16"]` where:
• `Fx8` = `fixed`(size=8)
• `Dur12` = `fixed`(size=12, `logicalType`=`duration`)
• `DecFix16` = `fixed`(size=16, `logicalType`=`decimal`, precision=10, scale=2)
**Notes:** Avro `duration` is a `fixed[12]` storing **months, days, millis** as three **little‑endian** 32‑bit integers; Avro `decimal` on `bytes`/`fixed` uses **two’s‑complement big‑endian** encoding of the unscaled integer. | +| `union_enum_records_array_map` | `[ColorU, RecA, RecB, array, map]` where:
• `ColorU` = `enum` {`RED`,`GREEN`,`BLUE`}
• `RecA` = `record` {`a:int`, `b:string`}
• `RecB` = `record` {`x:long`, `y:bytes`} | +| `union_date_or_fixed4` | `[int (logicalType=`date`), Fx4]` where `Fx4` = `fixed`(size=4) | +| `union_time_millis_or_enum` | `[int (logicalType=`time-millis`), OnOff]` where `OnOff` = `enum` {`ON`,`OFF`} | +| `union_time_micros_or_string` | `[long (logicalType=`time-micros`), string]` | +| `union_ts_millis_utc_or_array` | `[long (logicalType=`timestamp-millis`), array]` | +| `union_ts_micros_local_or_bytes` | `[long (logicalType=`local-timestamp-micros`), bytes]` | +| `union_uuid_or_fixed10` | `[string (logicalType=`uuid`), Fx10]` where `Fx10` = `fixed`(size=10) | +| `union_dec_bytes_or_dec_fixed` | `[bytes (decimal p=10 s=2), DecFix20]` where `DecFix20` = `fixed`(size=20, decimal p=20 s=4) — decimal encoding is big‑endian two’s‑complement. | +| `union_null_bytes_string` | `["null","bytes","string"]` | +| `array_of_union` | `array<["long","string"]>` | +| `map_of_union` | `map<["null","double"]>` | +| `record_with_union_field` | `HasUnion` = `record` {`id:int`, `u:["int","string"]`} | +| `union_ts_micros_utc_or_map` | `[long (logicalType=`timestamp-micros`), map]` | +| `union_ts_millis_local_or_string` | `[long (logicalType=`local-timestamp-millis`), string]` | +| `union_bool_or_string` | `["boolean","string"]` | + +**Implementation notes (generation):** + +* The script uses **fastavro’s tuple notation** `(branch_name, value)` to select branches in ambiguous unions (e.g., bytes vs. string, multiple named records). See *“Using the tuple notation to specify which branch of a union to take”* in the fastavro docs. +* Decimal values are pre‑encoded to the required **big‑endian two’s‑complement** byte sequence before writing (for both `bytes` and `fixed` decimal logical types). +* The `duration` logical type payloads are 12‑byte triples: **months / days / milliseconds**, little‑endian each. 
+ +**Source / Repro script:** +`create_avro_union_file.py` (Gist): contains the full writer schema, record builders covering four rows, and the `fastavro.writer` call which emits `union_fields.avro`. + ## Other Files -This directory contains other small OCF files used by `arrow-avro` tests. Details on these will be added in +This directory contains other small OCF files used by `arrow-avro` tests. Details on these will be added in follow-up PRs. \ No newline at end of file diff --git a/arrow-avro/test/data/union_fields.avro b/arrow-avro/test/data/union_fields.avro new file mode 100644 index 0000000000000000000000000000000000000000..e0ffb82bd412704065788cf908268e81de4511fc GIT binary patch literal 3430 zcma)8O=u)l5PtJ~4wER0ig;1jv9G(MGd+{ZWG`m3$s`BaKu8V}LeoF5lXm-O?4Rt+ zF2SI<;7LIVg5Y{kkUe`?y|{SrtTz!A$;I2U7ZDMRRj+@hXS%aVLnhttRekkU)z|fU zU3;^_o>%vM?-kuMxL$K3+kQbj2l^g&)Z2h6iRY$CBxt8Z{ z!gjEc?U@$0jR5SoZ)bLhs%kb@EH_l}!L=0@aBSNCez_*-Vxz;7t;# zEP}96_5)?iCL?$rf72RrLop&>fo8x1hD!m)jSyD2mG+@O>Kk8S4WDY4lrD?Drk#4X;H3uw&!mNY9mMUY@tbMV`UkaRyS5w zuH((p>su>{Skj6Wbq?yjipFKyR&gVBMfFaISIh4vqDV4LeM-YEofa-HiCC=@@kaT0 zvErt}i7KL%6>w53de7ip@v-Xd_@SYlZy`nAO*y_K*(v8 zf#(?t>kkED5G%V=Dw5E0xZ+s0Z3W09)j<;bOcv+ZT~Vj4kCPwPTO z>-|26s7%2LIoJ!^CE?aq(#POht^8T*tDSyzCqS)xl2M#==1o<#xBlS5iLl zDW>sY4sjZYH!zt7ryiZ>m<8#3H*Uyu`I2*FvavZPVm@cHpeU6w@+owR!vG=Anu++0 zhi_mppp^Uk%T&2@K-h=JZyfz{_luu;&)vUt{hNi_uQftg&|*g){e1V~Dyfo_x5?F| zYs)MB)ipxMqpv7A`uMX~=>@*eUHC*Lhd+MsZk19 z*S|;X1X)r#D>9(f3Q(qizt-%@U0rW>8-~o=jn3S{LPP7QMng5{OtrhvoYM^rY*{Vj z!;lh47J>m(&A<|g#Gzeyu?O(@G~8T-zdxSAmbq yvIHk_RW82u?~KG~#3;Q5E0hf^>fU+r(LXQGERq_nnO@{WCN5&q>WBx^0{S244>3jn literal 0 HcmV?d00001 From e08322bc041fe9e6f032c6565d33a5161ca10ffb Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 16:39:44 -0500 Subject: [PATCH 02/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git 
a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index c87b461e844a..93d19922ea5e 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -112,13 +112,10 @@ macro_rules! get_writer_union_action { ))); } let idx = branch as usize; - let dispatch = match $union_resolution.dispatch.as_deref() { - Some(d) => d, - None => { - return Err(ArrowError::SchemaError( - "dispatch table missing for writer=union".to_string(), - )); - } + let Some(dispatch) = $union_resolution.dispatch.as_deref() else { + return Err(ArrowError::SchemaError( + "dispatch table missing for writer=union".to_string(), + )); }; (idx, *dispatch.get(idx).unwrap_or(&BranchDispatch::NoMatch)) }}; From 1739751b311d80808f5f7404acc15d09b1904d2c Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 17:40:13 -0500 Subject: [PATCH 03/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 93d19922ea5e..dc6469c63e3a 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -321,8 +321,7 @@ impl UnionResolutionBuilder { "UnionResolutionBuilder for reader union requires fields".to_string(), ) })?; - let reader_type_codes: Vec = - fields.iter().map(|(tid, _)| tid).collect::>(); + let reader_type_codes: Vec = fields.iter().map(|(tid, _)| tid).collect(); let dispatch: Vec = info .writer_to_reader .iter() From 60240292fd42b25799be5d2c37cb81ee968eab28 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 17:45:19 -0500 Subject: [PATCH 04/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index dc6469c63e3a..bcafa680ae1f 100644 --- 
a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -348,11 +348,8 @@ impl UnionResolutionBuilder { })?; let reader_type_codes: Vec = fields.iter().map(|(tid, _)| tid).collect::>(); - let (target_reader_index, promotion) = - match info.writer_to_reader.first().and_then(|x| *x) { - Some(pair) => pair, - None => { - return Err(ArrowError::SchemaError( + let Some(Some((target_reader_index, promotion))) = info.writer_to_reader.first() else { + return Err(ArrowError::SchemaError( "Writer schema does not match any reader union branch".to_string(), )) } From fdeb875a6c752e2ff934ed2244a74e3933a81bfb Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 17:50:50 -0500 Subject: [PATCH 05/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index bcafa680ae1f..d68de58c2817 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -466,12 +466,10 @@ impl Decoder { Vec::new(), Some(union_resolution), ); - return Ok(match data_type.nullability() { - Some(n) => { - Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)) - } - None => base, - }); + if let Some(n) = match data_type.nullability() { + base = Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)); + } + return Ok(base); } } Self::try_new_internal(data_type) From 79d113c9f63ee35b7c7e9796778cc46aeb06eabd Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 18:07:26 -0500 Subject: [PATCH 06/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index d68de58c2817..459cfe878d0b 100644 --- a/arrow-avro/src/reader/record.rs +++ 
b/arrow-avro/src/reader/record.rs @@ -1303,7 +1303,7 @@ impl Decoder { Promotion::Direct => self.decode(buf), Promotion::IntToLong => match self { Self::Int64(v) => { - v.push(buf.get_int()? as i64); + v.push(buf.get_int()?.into()); Ok(()) } _ => Err(ArrowError::ParseError( From e7934457af51fee9d6bae0acf779bd86e64476c3 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 18:07:45 -0500 Subject: [PATCH 07/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 459cfe878d0b..312bae6465de 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -642,10 +642,7 @@ impl Decoder { "Sparse Arrow unions are not yet supported".to_string(), )); } - let mut decoders = Vec::with_capacity(encodings.len()); - for c in encodings.iter() { - decoders.push(Self::try_new_internal(c)?); - } + let decoders = encodings.iter().map(Self::try_new_internal); let union_resolution = match data_type.resolution.as_ref() { Some(ResolutionInfo::Union(info)) if info.reader_is_union => Some( UnionResolutionBuilder::new() From d6364515be2e96a94f9276fa3544db66ac2ef82a Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 20 Sep 2025 18:14:04 -0500 Subject: [PATCH 08/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 38 ++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 312bae6465de..9d806f29a8e2 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -348,18 +348,18 @@ impl UnionResolutionBuilder { })?; let reader_type_codes: Vec = fields.iter().map(|(tid, _)| tid).collect::>(); - let Some(Some((target_reader_index, promotion))) = info.writer_to_reader.first() 
else { + let Some(Some((target_reader_index, promotion))) = info.writer_to_reader.first() + else { return Err(ArrowError::SchemaError( - "Writer schema does not match any reader union branch".to_string(), - )) - } - }; + "Writer schema does not match any reader union branch".to_string(), + )); + }; Ok(UnionResolution { dispatch: None, kind: UnionResolvedKind::FromSingle { reader_type_codes: Arc::from(reader_type_codes), - target_reader_index, - promotion, + target_reader_index: *target_reader_index, + promotion: *promotion, }, }) } @@ -458,7 +458,7 @@ impl Decoder { if let UnionResolvedKind::ToSingle { target: t } = &mut union_resolution.kind { *t = target; } - let base = Self::Union( + let mut base = Self::Union( UnionFields::empty(), Vec::new(), Vec::new(), @@ -466,8 +466,9 @@ impl Decoder { Vec::new(), Some(union_resolution), ); - if let Some(n) = match data_type.nullability() { - base = Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)); + if let Some(n) = data_type.nullability() { + base = + Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)); } return Ok(base); } @@ -642,7 +643,10 @@ impl Decoder { "Sparse Arrow unions are not yet supported".to_string(), )); } - let decoders = encodings.iter().map(Self::try_new_internal); + let decoders = encodings + .iter() + .map(Self::try_new_internal) + .collect::, _>>()?; let union_resolution = match data_type.resolution.as_ref() { Some(ResolutionInfo::Union(info)) if info.reader_is_union => Some( UnionResolutionBuilder::new() @@ -723,14 +727,10 @@ impl Decoder { Self::Enum(indices, _, _) => indices.push(0), Self::Duration(builder) => builder.append_null(), Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { - let mut chosen = None; - for (i, ch) in encodings.iter().enumerate() { - if matches!(ch, Decoder::Null(_)) { - chosen = Some(i); - break; - } - } - let idx = chosen.unwrap_or(0); + let idx = encodings + .iter() + .position(|ch| matches!(ch, 
Decoder::Null(_))) + .unwrap_or(0); let type_id = fields .iter() .nth(idx) From 58890ae80848fbc78158109ed07c56280480ecbf Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 21 Sep 2025 00:32:32 -0500 Subject: [PATCH 09/22] Address PR Comments --- arrow-avro/src/codec.rs | 22 +++ arrow-avro/src/reader/record.rs | 338 ++++++++++++++++---------------- 2 files changed, 191 insertions(+), 169 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 64fc0488e301..63af652b802e 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::codec::Promotion::{ + BytesToString, Direct, FloatToDouble, IntToDouble, IntToFloat, IntToLong, LongToDouble, + LongToFloat, StringToBytes, +}; use crate::schema::{ make_full_name, Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, @@ -29,6 +33,8 @@ use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use indexmap::IndexMap; use serde_json::Value; use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::fmt::Display; use std::sync::Arc; use strum_macros::AsRefStr; @@ -117,6 +123,22 @@ pub(crate) enum Promotion { BytesToString, } +impl Display for Promotion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Direct => write!(f, "Direct"), + IntToLong => write!(f, "Int->Long"), + IntToFloat => write!(f, "Int->Float"), + IntToDouble => write!(f, "Int->Double"), + LongToFloat => write!(f, "Long->Float"), + LongToDouble => write!(f, "Long->Double"), + FloatToDouble => write!(f, "Float->Double"), + StringToBytes => write!(f, "String->Bytes"), + BytesToString => write!(f, "Bytes->String"), + } + } +} + /// Information required to resolve a writer union against a reader union (or single type). 
#[derive(Debug, Clone, PartialEq)] pub struct ResolvedUnion { diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 9d806f29a8e2..32492e131983 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -121,6 +121,22 @@ macro_rules! get_writer_union_action { }}; } +macro_rules! promote_numeric { + ($self:expr, $buf:expr, $variant:ident, $getter:ident, $to:ty, $promotion:expr) => {{ + match $self { + Decoder::$variant(v) => { + let x = $buf.$getter()?; + v.push(x as $to); + Ok(()) + } + _ => Err(ArrowError::ParseError(format!( + "Promotion {} target mismatch", + $promotion + ))), + } + }}; +} + #[derive(Debug)] pub(crate) struct RecordDecoderBuilder<'a> { data_type: &'a AvroDataType, @@ -288,7 +304,6 @@ struct UnionResolutionBuilder { } impl UnionResolutionBuilder { - #[inline] fn new() -> Self { Self { fields: None, @@ -296,15 +311,13 @@ impl UnionResolutionBuilder { } } - #[inline] fn with_fields(mut self, fields: UnionFields) -> Self { self.fields = Some(fields); self } - #[inline] - fn with_resolved_union(mut self, resolved_union: &ResolvedUnion) -> Self { - self.resolved = Some(resolved_union.clone()); + fn with_resolved_union(mut self, resolved_union: ResolvedUnion) -> Self { + self.resolved = Some(resolved_union); self } @@ -314,75 +327,66 @@ impl UnionResolutionBuilder { "UnionResolutionBuilder requires resolved_union to be provided".to_string(), ) })?; - match (info.writer_is_union, info.reader_is_union) { - (true, true) => { - let fields = self.fields.ok_or_else(|| { - ArrowError::InvalidArgumentError( - "UnionResolutionBuilder for reader union requires fields".to_string(), - ) - })?; - let reader_type_codes: Vec = fields.iter().map(|(tid, _)| tid).collect(); - let dispatch: Vec = info - .writer_to_reader - .iter() - .map(|m| match m { - Some((reader_index, promotion)) => BranchDispatch::ToReader { - reader_idx: *reader_index, - promotion: *promotion, - }, - None => BranchDispatch::NoMatch, - }) - 
.collect(); - Ok(UnionResolution { - dispatch: Some(Arc::from(dispatch)), - kind: UnionResolvedKind::Both { - reader_type_codes: Arc::from(reader_type_codes), + let writer_dispatch: Option> = info.writer_is_union.then(|| { + let dispatches: Vec = info + .writer_to_reader + .iter() + .map(|m| match m { + Some((reader_idx, promotion)) => BranchDispatch::ToReader { + reader_idx: *reader_idx, + promotion: *promotion, }, + None => BranchDispatch::NoMatch, }) - } - (false, true) => { - let fields = self.fields.ok_or_else(|| { - ArrowError::InvalidArgumentError( - "UnionResolutionBuilder for reader union requires fields".to_string(), - ) - })?; - let reader_type_codes: Vec = - fields.iter().map(|(tid, _)| tid).collect::>(); + .collect(); + Arc::from(dispatches) + }); + let reader_type_codes: Option> = if info.reader_is_union { + let fields = self.fields.ok_or_else(|| { + ArrowError::InvalidArgumentError( + "UnionResolutionBuilder for reader union requires fields".to_string(), + ) + })?; + let codes: Vec = fields.iter().map(|(tid, _)| tid).collect(); + Some(Arc::from(codes)) + } else { + None + }; + match (writer_dispatch, reader_type_codes) { + (Some(dispatch), Some(reader_type_codes)) => Ok(UnionResolution { + dispatch: Some(dispatch), + kind: UnionResolvedKind::Both { reader_type_codes }, + }), + (Some(dispatch), None) => Ok(UnionResolution { + dispatch: Some(dispatch), + kind: UnionResolvedKind::ToSingle { + target: Box::new(Decoder::Null(0)), + }, + }), + (None, Some(reader_type_codes)) => { + // writer_is_union == false in this arm, so writer_to_reader must be a singleton: + // [ Some((reader_idx, promotion)) ] if the single writer type matches any + // branch of the reader union; [ None ] otherwise. 
+ debug_assert!( + !info.writer_is_union && info.writer_to_reader.len() == 1, + "internal invariant: expected a single writer to reader mapping when writer_is_union=false" + ); let Some(Some((target_reader_index, promotion))) = info.writer_to_reader.first() else { return Err(ArrowError::SchemaError( - "Writer schema does not match any reader union branch".to_string(), + "Writer type does not match any branch of the reader union".to_string(), )); }; Ok(UnionResolution { dispatch: None, kind: UnionResolvedKind::FromSingle { - reader_type_codes: Arc::from(reader_type_codes), + reader_type_codes, target_reader_index: *target_reader_index, promotion: *promotion, }, }) } - (true, false) => { - let dispatch: Vec = info - .writer_to_reader - .iter() - .map(|m| match m { - Some((reader_index, promotion)) => BranchDispatch::ToReader { - reader_idx: *reader_index, - promotion: *promotion, - }, - None => BranchDispatch::NoMatch, - }) - .collect(); - Ok(UnionResolution { - dispatch: Some(Arc::from(dispatch)), - kind: UnionResolvedKind::ToSingle { - target: Box::new(Decoder::Null(0)), - }, - }) - } - (false, false) => Err(ArrowError::InvalidArgumentError( + (None, None) => Err(ArrowError::InvalidArgumentError( "UnionResolutionBuilder used for non-union case".to_string(), )), } @@ -453,7 +457,7 @@ impl Decoder { clone.resolution = None; let target = Box::new(Self::try_new_internal(&clone)?); let mut union_resolution = UnionResolutionBuilder::new() - .with_resolved_union(info) + .with_resolved_union(info.clone()) .build()?; if let UnionResolvedKind::ToSingle { target: t } = &mut union_resolution.kind { *t = target; @@ -647,11 +651,23 @@ impl Decoder { .iter() .map(Self::try_new_internal) .collect::, _>>()?; + let fields_len = fields.iter().count(); + debug_assert_eq!( + fields_len, + decoders.len(), + "Union fields and decoders must align" + ); + if fields_len != decoders.len() { + return Err(ArrowError::SchemaError(format!( + "UnionFields/encodings length mismatch at 
construction: fields_len={fields_len}, encodings_len={}", + decoders.len() + ))); + } let union_resolution = match data_type.resolution.as_ref() { Some(ResolutionInfo::Union(info)) if info.reader_is_union => Some( UnionResolutionBuilder::new() .with_fields(fields.clone()) - .with_resolved_union(info) + .with_resolved_union(info.clone()) .build()?, ), _ => None, @@ -683,7 +699,7 @@ impl Decoder { } /// Append a null record - fn append_null(&mut self) { + fn append_null(&mut self) -> Result<(), ArrowError> { match self { Self::Null(count) => *count += 1, Self::Boolean(b) => b.append(false), @@ -711,7 +727,11 @@ impl Decoder { Self::Array(_, offsets, _e) => { offsets.push_length(0); } - Self::Record(_, e, _) => e.iter_mut().for_each(|e| e.append_null()), + Self::Record(_, e, _) => { + for encoding in e.iter_mut() { + encoding.append_null(); + } + } Self::Map(_, _koff, moff, _, _) => { moff.push_length(0); } @@ -727,16 +747,16 @@ impl Decoder { Self::Enum(indices, _, _) => indices.push(0), Self::Duration(builder) => builder.append_null(), Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { + if encodings.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Cannot append null to empty union".to_string(), + )); + } let idx = encodings .iter() .position(|ch| matches!(ch, Decoder::Null(_))) .unwrap_or(0); - let type_id = fields - .iter() - .nth(idx) - .map(|(type_id, _)| type_id) - .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, idx, encodings.len())?); offsets.push(encoding_counts[idx]); encodings[idx].append_null(); encoding_counts[idx] += 1; @@ -758,12 +778,7 @@ impl Decoder { } } let idx = chosen.unwrap_or(0); - let type_id = fields - .iter() - .nth(idx) - .map(|(type_id, _)| type_id) - .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, idx, encodings.len())?); offsets.push(encoding_counts[idx]); 
encodings[idx].append_null(); encoding_counts[idx] += 1; @@ -775,12 +790,7 @@ impl Decoder { target_reader_index, .. } => { - let type_id = fields - .iter() - .nth(*target_reader_index) - .map(|(type_id, _)| type_id) - .unwrap_or(0); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, *target_reader_index, encodings.len())?); offsets.push(encoding_counts[*target_reader_index]); encodings[*target_reader_index].append_null(); encoding_counts[*target_reader_index] += 1; @@ -791,6 +801,7 @@ impl Decoder { inner.append_null(); } } + Ok(()) } /// Append a single default literal into the decoder's buffers @@ -799,8 +810,7 @@ impl Decoder { Self::Nullable(_, nb, inner) => { if matches!(lit, AvroLiteral::Null) { nb.append(false); - inner.append_null(); - Ok(()) + inner.append_null() } else { nb.append(true); inner.append_default(lit) @@ -1006,12 +1016,7 @@ impl Decoder { "Union default cannot be applied to empty union".to_string(), )); } - let type_id = fields - .iter() - .nth(0) - .map(|(type_id, _)| type_id) - .unwrap_or(0_i8); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, 0, encodings.len())?); offsets.push(encoding_counts[0]); encodings[0].append_default(lit)?; encoding_counts[0] += 1; @@ -1031,12 +1036,7 @@ impl Decoder { "Union default cannot be applied to empty union".to_string(), )); } - let type_id = fields - .iter() - .nth(0) - .map(|(type_id, _)| type_id) - .unwrap_or(0_i8); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, 0, encodings.len())?); offsets.push(encoding_counts[0]); encodings[0].append_default(lit)?; encoding_counts[0] += 1; @@ -1047,12 +1047,7 @@ impl Decoder { target_reader_index, .. 
} => { - let type_id = fields - .iter() - .nth(*target_reader_index) - .map(|(type_id, _)| type_id) - .unwrap_or(0_i8); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, *target_reader_index, encodings.len())?); offsets.push(encoding_counts[*target_reader_index]); encodings[*target_reader_index].append_default(lit)?; encoding_counts[*target_reader_index] += 1; @@ -1207,12 +1202,7 @@ impl Decoder { encodings.len() ))); } - let type_id = fields - .iter() - .nth(idx) - .map(|(type_id, _)| type_id) - .unwrap_or_else(|| i8::try_from(idx).unwrap_or(0)); - type_ids.push(type_id); + type_ids.push(type_id_at(fields, idx, encodings.len())?); offsets.push(encoding_counts[idx]); encodings[idx].decode(buf)?; encoding_counts[idx] += 1; @@ -1298,60 +1288,18 @@ impl Decoder { ) -> Result<(), ArrowError> { match promotion { Promotion::Direct => self.decode(buf), - Promotion::IntToLong => match self { - Self::Int64(v) => { - v.push(buf.get_int()?.into()); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Int->Long target mismatch".into(), - )), - }, - Promotion::IntToFloat => match self { - Self::Float32(v) => { - v.push(buf.get_int()? as f32); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Int->Float target mismatch".into(), - )), - }, - Promotion::IntToDouble => match self { - Self::Float64(v) => { - v.push(buf.get_int()? as f64); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Int->Double target mismatch".into(), - )), - }, - Promotion::LongToFloat => match self { - Self::Float32(v) => { - v.push(buf.get_long()? as f32); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Long->Float target mismatch".into(), - )), - }, - Promotion::LongToDouble => match self { - Self::Float64(v) => { - v.push(buf.get_long()? as f64); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Long->Double target mismatch".into(), - )), - }, - Promotion::FloatToDouble => match self { - Self::Float64(v) => { - v.push(buf.get_float()? 
as f64); - Ok(()) - } - _ => Err(ArrowError::ParseError( - "Promotion Float->Double target mismatch".into(), - )), - }, + Promotion::IntToLong => promote_numeric!(self, buf, Int64, get_int, i64, promotion), + Promotion::IntToFloat => promote_numeric!(self, buf, Float32, get_int, f32, promotion), + Promotion::IntToDouble => promote_numeric!(self, buf, Float64, get_int, f64, promotion), + Promotion::LongToFloat => { + promote_numeric!(self, buf, Float32, get_long, f32, promotion) + } + Promotion::LongToDouble => { + promote_numeric!(self, buf, Float64, get_long, f64, promotion) + } + Promotion::FloatToDouble => { + promote_numeric!(self, buf, Float64, get_float, f64, promotion) + } Promotion::StringToBytes => match self { Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => { let data = buf.get_bytes()?; @@ -1359,9 +1307,9 @@ impl Decoder { values.extend_from_slice(data); Ok(()) } - _ => Err(ArrowError::ParseError( - "Promotion String->Bytes target mismatch".into(), - )), + _ => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch", + ))), }, Promotion::BytesToString => match self { Self::String(offsets, values) @@ -1372,9 +1320,9 @@ impl Decoder { values.extend_from_slice(data); Ok(()) } - _ => Err(ArrowError::ParseError( - "Promotion Bytes->String target mismatch".into(), - )), + _ => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch", + ))), }, } } @@ -1565,6 +1513,20 @@ fn flush_dict( .map(|arr| Arc::new(arr) as ArrayRef) } +#[inline] +fn type_id_at(fields: &UnionFields, idx: usize, encodings_len: usize) -> Result { + fields + .iter() + .nth(idx) + .map(|(tid, _)| tid) + .ok_or_else(|| { + ArrowError::SchemaError(format!( + "UnionFields/encodings length mismatch: child_idx={idx}, fields_len={}, encodings_len={encodings_len}", + fields.iter().count() + )) + }) +} + #[inline] fn read_blocks( buf: &mut AvroCursor, @@ -1784,8 +1746,7 @@ impl Projector { if let Some(default_literal) = 
self.field_defaults[index].as_ref() { decoder.append_default(default_literal) } else { - decoder.append_null(); - Ok(()) + decoder.append_null() } } @@ -3885,4 +3846,43 @@ mod tests { assert_eq!(id.value(0), 99); assert_eq!(name.value(0), "alice"); } + + #[test] + fn union_type_ids_are_not_child_indexes() { + let encodings: Vec = + vec![avro_from_codec(Codec::Int32), avro_from_codec(Codec::Utf8)]; + let fields: UnionFields = [ + (42_i8, Arc::new(ArrowField::new("a", DataType::Int32, true))), + (7_i8, Arc::new(ArrowField::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + let dt = avro_from_codec(Codec::Union( + encodings.into(), + fields.clone(), + UnionMode::Dense, + )); + let mut dec = Decoder::try_new(&dt).expect("decoder"); + let mut b1 = encode_avro_long(1); + b1.extend(encode_avro_bytes("hi".as_bytes())); + dec.decode(&mut AvroCursor::new(&b1)).expect("decode b1"); + let mut b0 = encode_avro_long(0); + b0.extend(encode_avro_int(5)); + dec.decode(&mut AvroCursor::new(&b0)).expect("decode b0"); + let arr = dec.flush(None).expect("flush"); + let ua = arr.as_any().downcast_ref::().expect("union"); + assert_eq!(ua.len(), 2); + assert_eq!(ua.type_id(0), 7, "type id must come from UnionFields"); + assert_eq!(ua.type_id(1), 42, "type id must come from UnionFields"); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + let utf8_child = ua.child(7).as_any().downcast_ref::().unwrap(); + assert_eq!(utf8_child.len(), 1); + assert_eq!(utf8_child.value(0), "hi"); + let int_child = ua.child(42).as_any().downcast_ref::().unwrap(); + assert_eq!(int_child.len(), 1); + assert_eq!(int_child.value(0), 5); + let type_ids: Vec = fields.iter().map(|(tid, _)| tid).collect(); + assert_eq!(type_ids, vec![42_i8, 7_i8]); + } } From 8a53f1ae5122063f32d4676fb81a3906db68feb3 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 21 Sep 2025 16:12:22 -0500 Subject: [PATCH 10/22] Address PR Comments --- arrow-avro/src/reader/record.rs | 7 ++++--- 1 
file changed, 4 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 5fbdd42564a2..042342c5a4af 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -200,6 +200,7 @@ impl RecordDecoder { ) -> Result { match data_type.codec() { Codec::Struct(reader_fields) => { + // Build Arrow schema fields and per-child decoders let mut arrow_fields = Vec::with_capacity(reader_fields.len()); let mut encodings = Vec::with_capacity(reader_fields.len()); for avro_field in reader_fields.iter() { @@ -724,7 +725,7 @@ impl Decoder { Self::Uuid(v) => { v.extend([0; 16]); } - Self::Array(_, offsets, _e) => { + Self::Array(_, offsets, e) => { offsets.push_length(0); } Self::Record(_, e, _) => { @@ -1264,7 +1265,7 @@ impl Decoder { } }, Self::Nullable(order, nb, encoding) => { - let branch = buf.get_long()?; + let branch = buf.read_vlq()?; let is_not_null = match *order { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, @@ -1946,7 +1947,7 @@ impl Skipper { } } Self::Nullable(order, inner) => { - let branch = buf.get_long()?; + let branch = buf.read_vlq()?; let is_not_null = match *order { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, From 71cb7213bf10222680a872f1cff5530dd559b89f Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 21 Sep 2025 21:36:39 -0500 Subject: [PATCH 11/22] Refactored Union decoding logic into a new `UnionDecoder ` based approach --- arrow-avro/src/codec.rs | 24 +- arrow-avro/src/reader/record.rs | 775 +++++++++++++++----------------- arrow-avro/test/data/README.md | 4 +- 3 files changed, 369 insertions(+), 434 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 63af652b802e..9e2e6ea7bda5 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use crate::codec::Promotion::{ - BytesToString, Direct, FloatToDouble, IntToDouble, IntToFloat, IntToLong, LongToDouble, - LongToFloat, StringToBytes, -}; use crate::schema::{ make_full_name, Array, Attributes, AvroSchema, ComplexType, Enum, Fixed, Map, Nullability, PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, @@ -124,17 +120,17 @@ pub(crate) enum Promotion { } impl Display for Promotion { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Direct => write!(f, "Direct"), - IntToLong => write!(f, "Int->Long"), - IntToFloat => write!(f, "Int->Float"), - IntToDouble => write!(f, "Int->Double"), - LongToFloat => write!(f, "Long->Float"), - LongToDouble => write!(f, "Long->Double"), - FloatToDouble => write!(f, "Float->Double"), - StringToBytes => write!(f, "String->Bytes"), - BytesToString => write!(f, "Bytes->String"), + Self::Direct => write!(formatter, "Direct"), + Self::IntToLong => write!(formatter, "Int->Long"), + Self::IntToFloat => write!(formatter, "Int->Float"), + Self::IntToDouble => write!(formatter, "Int->Double"), + Self::LongToFloat => write!(formatter, "Long->Float"), + Self::LongToDouble => write!(formatter, "Long->Double"), + Self::FloatToDouble => write!(formatter, "Float->Double"), + Self::StringToBytes => write!(formatter, "String->Bytes"), + Self::BytesToString => write!(formatter, "Bytes->String"), } } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 042342c5a4af..e5e619ca1075 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -19,12 +19,9 @@ use crate::codec::{ AvroDataType, AvroField, AvroLiteral, Codec, Promotion, ResolutionInfo, ResolvedRecord, ResolvedUnion, }; -use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::schema::Nullability; -use arrow_array::builder::{ - Decimal128Builder, 
Decimal256Builder, IntervalMonthDayNanoBuilder, StringViewBuilder, -}; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder, IntervalMonthDayNanoBuilder}; #[cfg(feature = "small_decimals")] use arrow_array::builder::{Decimal32Builder, Decimal64Builder}; use arrow_array::types::*; @@ -84,43 +81,6 @@ macro_rules! append_decimal_default { }}; } -macro_rules! flush_union { - ($fields:expr, $type_ids:expr, $offsets:expr, $encodings:expr) => {{ - let encoding_arrays = $encodings - .iter_mut() - .map(|d| d.flush(None)) - .collect::, _>>()?; - let type_ids_buf: ScalarBuffer = flush_values($type_ids).into_iter().collect(); - let offsets_buf: ScalarBuffer = flush_values($offsets).into_iter().collect(); - let arr = UnionArray::try_new( - $fields.clone(), - type_ids_buf, - Some(offsets_buf), - encoding_arrays, - ) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Arc::new(arr) - }}; -} - -macro_rules! get_writer_union_action { - ($buf:expr, $union_resolution:expr) => {{ - let branch = $buf.get_long()?; - if branch < 0 { - return Err(ArrowError::ParseError(format!( - "Negative union branch index {branch}" - ))); - } - let idx = branch as usize; - let Some(dispatch) = $union_resolution.dispatch.as_deref() else { - return Err(ArrowError::SchemaError( - "dispatch table missing for writer=union".to_string(), - )); - }; - (idx, *dispatch.get(idx).unwrap_or(&BranchDispatch::NoMatch)) - }}; -} - macro_rules! 
promote_numeric { ($self:expr, $buf:expr, $variant:ident, $getter:ident, $to:ty, $promotion:expr) => {{ match $self { @@ -268,132 +228,6 @@ struct EnumResolution { default_index: i32, } -#[derive(Debug, Clone, Copy)] -enum BranchDispatch { - NoMatch, - ToReader { - reader_idx: usize, - promotion: Promotion, - }, -} - -#[derive(Debug)] -struct UnionResolution { - dispatch: Option>, - kind: UnionResolvedKind, -} - -#[derive(Debug)] -enum UnionResolvedKind { - Both { - reader_type_codes: Arc<[i8]>, - }, - ToSingle { - target: Box, - }, - FromSingle { - reader_type_codes: Arc<[i8]>, - target_reader_index: usize, - promotion: Promotion, - }, -} - -#[derive(Debug, Default)] -struct UnionResolutionBuilder { - fields: Option, - resolved: Option, -} - -impl UnionResolutionBuilder { - fn new() -> Self { - Self { - fields: None, - resolved: None, - } - } - - fn with_fields(mut self, fields: UnionFields) -> Self { - self.fields = Some(fields); - self - } - - fn with_resolved_union(mut self, resolved_union: ResolvedUnion) -> Self { - self.resolved = Some(resolved_union); - self - } - - fn build(self) -> Result { - let info = self.resolved.ok_or_else(|| { - ArrowError::InvalidArgumentError( - "UnionResolutionBuilder requires resolved_union to be provided".to_string(), - ) - })?; - let writer_dispatch: Option> = info.writer_is_union.then(|| { - let dispatches: Vec = info - .writer_to_reader - .iter() - .map(|m| match m { - Some((reader_idx, promotion)) => BranchDispatch::ToReader { - reader_idx: *reader_idx, - promotion: *promotion, - }, - None => BranchDispatch::NoMatch, - }) - .collect(); - Arc::from(dispatches) - }); - let reader_type_codes: Option> = if info.reader_is_union { - let fields = self.fields.ok_or_else(|| { - ArrowError::InvalidArgumentError( - "UnionResolutionBuilder for reader union requires fields".to_string(), - ) - })?; - let codes: Vec = fields.iter().map(|(tid, _)| tid).collect(); - Some(Arc::from(codes)) - } else { - None - }; - match (writer_dispatch, 
reader_type_codes) { - (Some(dispatch), Some(reader_type_codes)) => Ok(UnionResolution { - dispatch: Some(dispatch), - kind: UnionResolvedKind::Both { reader_type_codes }, - }), - (Some(dispatch), None) => Ok(UnionResolution { - dispatch: Some(dispatch), - kind: UnionResolvedKind::ToSingle { - target: Box::new(Decoder::Null(0)), - }, - }), - (None, Some(reader_type_codes)) => { - // writer_is_union == false in this arm, so writer_to_reader must be a singleton: - // [ Some((reader_idx, promotion)) ] if the single writer type matches any - // branch of the reader union; [ None ] otherwise. - debug_assert!( - !info.writer_is_union && info.writer_to_reader.len() == 1, - "internal invariant: expected a single writer to reader mapping when writer_is_union=false" - ); - let Some(Some((target_reader_index, promotion))) = info.writer_to_reader.first() - else { - return Err(ArrowError::SchemaError( - "Writer type does not match any branch of the reader union".to_string(), - )); - }; - Ok(UnionResolution { - dispatch: None, - kind: UnionResolvedKind::FromSingle { - reader_type_codes, - target_reader_index: *target_reader_index, - promotion: *promotion, - }, - }) - } - (None, None) => Err(ArrowError::InvalidArgumentError( - "UnionResolutionBuilder used for non-union case".to_string(), - )), - } - } -} - #[derive(Debug)] enum Decoder { Null(usize), @@ -439,43 +273,25 @@ enum Decoder { Decimal64(usize, Option, Option, Decimal64Builder), Decimal128(usize, Option, Option, Decimal128Builder), Decimal256(usize, Option, Option, Decimal256Builder), - Union( - UnionFields, - Vec, - Vec, - Vec, - Vec, - Option, - ), + Union(UnionDecoder), Nullable(Nullability, NullBufferBuilder, Box), } impl Decoder { fn try_new(data_type: &AvroDataType) -> Result { + // Extract just the Promotion (if any) to simplify pattern matching if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { if info.writer_is_union && !info.reader_is_union { let mut clone = data_type.clone(); - 
clone.resolution = None; + clone.resolution = None; // Build target base decoder without Union resolution let target = Box::new(Self::try_new_internal(&clone)?); - let mut union_resolution = UnionResolutionBuilder::new() - .with_resolved_union(info.clone()) - .build()?; - if let UnionResolvedKind::ToSingle { target: t } = &mut union_resolution.kind { - *t = target; - } - let mut base = Self::Union( - UnionFields::empty(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Some(union_resolution), + let decoder = Self::Union( + UnionDecoderBuilder::new() + .with_resolved_union(info.clone()) + .with_target(target) + .build()?, ); - if let Some(n) = data_type.nullability() { - base = - Self::Nullable(n, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(base)); - } - return Ok(base); + return Ok(decoder); } } Self::try_new_internal(data_type) @@ -652,35 +468,22 @@ impl Decoder { .iter() .map(Self::try_new_internal) .collect::, _>>()?; - let fields_len = fields.iter().count(); - debug_assert_eq!( - fields_len, - decoders.len(), - "Union fields and decoders must align" - ); - if fields_len != decoders.len() { + if fields.len() != decoders.len() { return Err(ArrowError::SchemaError(format!( - "UnionFields/encodings length mismatch at construction: fields_len={fields_len}, encodings_len={}", + "Union has {} fields but {} decoders", + fields.len(), decoders.len() ))); } - let union_resolution = match data_type.resolution.as_ref() { - Some(ResolutionInfo::Union(info)) if info.reader_is_union => Some( - UnionResolutionBuilder::new() - .with_fields(fields.clone()) - .with_resolved_union(info.clone()) - .build()?, - ), - _ => None, - }; - Self::Union( - fields.clone(), - Vec::with_capacity(DEFAULT_CAPACITY), - Vec::with_capacity(DEFAULT_CAPACITY), - decoders, - vec![0; encodings.len()], - union_resolution, - ) + let mut builder = UnionDecoderBuilder::new() + .with_fields(fields.clone()) + .with_branches(decoders); + if let Some(ResolutionInfo::Union(info)) = 
data_type.resolution.as_ref() { + if info.reader_is_union { + builder = builder.with_resolved_union(info.clone()); + } + } + Self::Union(builder.build()?) } (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), (&Codec::Union(_, _, _), _) => { @@ -747,56 +550,7 @@ impl Decoder { Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), Self::Enum(indices, _, _) => indices.push(0), Self::Duration(builder) => builder.append_null(), - Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { - if encodings.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "Cannot append null to empty union".to_string(), - )); - } - let idx = encodings - .iter() - .position(|ch| matches!(ch, Decoder::Null(_))) - .unwrap_or(0); - type_ids.push(type_id_at(fields, idx, encodings.len())?); - offsets.push(encoding_counts[idx]); - encodings[idx].append_null(); - encoding_counts[idx] += 1; - } - Self::Union( - fields, - type_ids, - offsets, - encodings, - encoding_counts, - Some(union_resolution), - ) => match &mut union_resolution.kind { - UnionResolvedKind::Both { .. } => { - let mut chosen = None; - for (i, ch) in encodings.iter().enumerate() { - if matches!(ch, Decoder::Null(_)) { - chosen = Some(i); - break; - } - } - let idx = chosen.unwrap_or(0); - type_ids.push(type_id_at(fields, idx, encodings.len())?); - offsets.push(encoding_counts[idx]); - encodings[idx].append_null(); - encoding_counts[idx] += 1; - } - UnionResolvedKind::ToSingle { target } => { - target.append_null(); - } - UnionResolvedKind::FromSingle { - target_reader_index, - .. 
- } => { - type_ids.push(type_id_at(fields, *target_reader_index, encodings.len())?); - offsets.push(encoding_counts[*target_reader_index]); - encodings[*target_reader_index].append_null(); - encoding_counts[*target_reader_index] += 1; - } - }, + Self::Union(u) => u.append_null()?, Self::Nullable(_, null_buffer, inner) => { null_buffer.append(false); inner.append_null(); @@ -1011,50 +765,7 @@ impl Decoder { "Default for enum must be a symbol".to_string(), )), }, - Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { - if encodings.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "Union default cannot be applied to empty union".to_string(), - )); - } - type_ids.push(type_id_at(fields, 0, encodings.len())?); - offsets.push(encoding_counts[0]); - encodings[0].append_default(lit)?; - encoding_counts[0] += 1; - Ok(()) - } - Self::Union( - fields, - type_ids, - offsets, - encodings, - encoding_counts, - Some(union_resolution), - ) => match &mut union_resolution.kind { - UnionResolvedKind::Both { .. } => { - if encodings.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "Union default cannot be applied to empty union".to_string(), - )); - } - type_ids.push(type_id_at(fields, 0, encodings.len())?); - offsets.push(encoding_counts[0]); - encodings[0].append_default(lit)?; - encoding_counts[0] += 1; - Ok(()) - } - UnionResolvedKind::ToSingle { target } => target.append_default(lit), - UnionResolvedKind::FromSingle { - target_reader_index, - .. 
- } => { - type_ids.push(type_id_at(fields, *target_reader_index, encodings.len())?); - offsets.push(encoding_counts[*target_reader_index]); - encodings[*target_reader_index].append_default(lit)?; - encoding_counts[*target_reader_index] += 1; - Ok(()) - } - }, + Self::Union(u) => u.append_default(lit), Self::Record(field_meta, decoders, projector) => match lit { AvroLiteral::Map(entries) => { for (i, dec) in decoders.iter_mut().enumerate() { @@ -1189,81 +900,7 @@ impl Decoder { let nanos = (millis as i64) * 1_000_000; builder.append_value(IntervalMonthDayNano::new(months as i32, days as i32, nanos)); } - Self::Union(fields, type_ids, offsets, encodings, encoding_counts, None) => { - let branch = buf.get_long()?; - if branch < 0 { - return Err(ArrowError::ParseError(format!( - "Negative union branch index {branch}" - ))); - } - let idx = branch as usize; - if idx >= encodings.len() { - return Err(ArrowError::ParseError(format!( - "Union branch index {idx} out of range ({} branches)", - encodings.len() - ))); - } - type_ids.push(type_id_at(fields, idx, encodings.len())?); - offsets.push(encoding_counts[idx]); - encodings[idx].decode(buf)?; - encoding_counts[idx] += 1; - } - Self::Union( - _, - type_ids, - offsets, - encodings, - encoding_counts, - Some(union_resolution), - ) => match &mut union_resolution.kind { - UnionResolvedKind::Both { - reader_type_codes, .. 
- } => { - let (idx, action) = get_writer_union_action!(buf, union_resolution); - match action { - BranchDispatch::NoMatch => { - return Err(ArrowError::ParseError(format!( - "Union branch index {idx} not resolvable by reader schema" - ))); - } - BranchDispatch::ToReader { - reader_idx, - promotion, - } => { - let type_id = reader_type_codes[reader_idx]; - type_ids.push(type_id); - offsets.push(encoding_counts[reader_idx]); - encodings[reader_idx].decode_with_promotion(buf, promotion)?; - encoding_counts[reader_idx] += 1; - } - } - } - UnionResolvedKind::ToSingle { target } => { - let (idx, action) = get_writer_union_action!(buf, union_resolution); - match action { - BranchDispatch::NoMatch => { - return Err(ArrowError::ParseError(format!( - "Writer union branch {idx} does not resolve to reader type" - ))); - } - BranchDispatch::ToReader { promotion, .. } => { - target.decode_with_promotion(buf, promotion)?; - } - } - } - UnionResolvedKind::FromSingle { - reader_type_codes, - target_reader_index, - promotion, - .. - } => { - let type_id = reader_type_codes[*target_reader_index]; - type_ids.push(type_id); - offsets.push(encoding_counts[*target_reader_index]); - encodings[*target_reader_index].decode_with_promotion(buf, *promotion)?; - encoding_counts[*target_reader_index] += 1; - } - }, + Self::Union(u) => u.decode(buf)?, Self::Nullable(order, nb, encoding) => { let branch = buf.read_vlq()?; let is_not_null = match *order { @@ -1467,19 +1104,340 @@ impl Decoder { .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(vals) } - Self::Union(fields, type_ids, offsets, encodings, _, None) => { - flush_union!(fields, type_ids, offsets, encodings) - } - Self::Union(fields, type_ids, offsets, encodings, _, Some(union_resolution)) => { - match &mut union_resolution.kind { - UnionResolvedKind::Both { .. } | UnionResolvedKind::FromSingle { .. 
} => { - flush_union!(fields, type_ids, offsets, encodings) - } - UnionResolvedKind::ToSingle { target } => target.flush(nulls)?, + Self::Union(u) => u.flush(nulls)?, + }) + } +} + +#[derive(Debug)] +struct DispatchLut { + to_reader: Box<[i16]>, + promotion: Box<[Promotion]>, +} + +impl DispatchLut { + fn from_writer_to_reader(promotion_map: &[Option<(usize, Promotion)>]) -> Self { + let mut to_reader = Vec::with_capacity(promotion_map.len()); + let mut promotion = Vec::with_capacity(promotion_map.len()); + for map in promotion_map { + match *map { + Some((idx, promo)) => { + debug_assert!(idx <= i16::MAX as usize); + to_reader.push(idx as i16); + promotion.push(promo); + } + None => { + to_reader.push(-1); + promotion.push(Promotion::Direct); } } + } + Self { + to_reader: to_reader.into_boxed_slice(), + promotion: promotion.into_boxed_slice(), + } + } + + // Resolve a writer branch index to (reader_idx, promotion) + #[inline] + fn resolve(&self, writer_idx: usize) -> Option<(usize, Promotion)> { + if writer_idx >= self.to_reader.len() { + return None; + } + let reader_index = self.to_reader[writer_idx]; + if reader_index < 0 { + None + } else { + Some((reader_index as usize, self.promotion[writer_idx])) + } + } +} + +#[derive(Debug)] +struct UnionDecoder { + fields: UnionFields, + type_ids: Vec, + offsets: Vec, + branches: Vec, + counts: Vec, + type_id_by_reader_idx: Arc<[i8]>, + null_branch: Option, + default_emit_idx: usize, + null_emit_idx: usize, + plan: UnionReadPlan, +} + +impl Default for UnionDecoder { + fn default() -> Self { + Self { + fields: UnionFields::empty(), + type_ids: Vec::new(), + offsets: Vec::new(), + branches: Vec::new(), + counts: Vec::new(), + type_id_by_reader_idx: Arc::from([]), + null_branch: None, + default_emit_idx: 0, + null_emit_idx: 0, + plan: UnionReadPlan::Passthrough, + } + } +} + +#[derive(Debug)] +enum UnionReadPlan { + ReaderUnion { + lookup_table: DispatchLut, + }, + FromSingle { + reader_idx: usize, + promotion: Promotion, 
+ }, + ToSingle { + target: Box, + lookup_table: DispatchLut, + }, + Passthrough, +} + +impl UnionDecoder { + fn try_new( + fields: UnionFields, + branches: Vec, + resolved: Option, + ) -> Result { + let reader_type_codes: Arc<[i8]> = + Arc::from(fields.iter().map(|(tid, _)| tid).collect::>()); + let null_branch = branches.iter().position(|b| matches!(b, Decoder::Null(_))); + let default_emit_idx = 0; + let null_emit_idx = null_branch.unwrap_or(default_emit_idx); + let plan = Self::plan_from_resolved(resolved)?; + let branch_len = branches.len().max(reader_type_codes.len()); + Ok(Self { + fields, + type_ids: Vec::with_capacity(DEFAULT_CAPACITY), + offsets: Vec::with_capacity(DEFAULT_CAPACITY), + branches, + counts: vec![0; branch_len], + type_id_by_reader_idx: reader_type_codes, + null_branch, + default_emit_idx, + null_emit_idx, + plan, + }) + } + + fn try_new_from_writer_union( + info: ResolvedUnion, + target: Box, + ) -> Result { + // This constructor is only for writer-union to single-type resolution + debug_assert!(info.writer_is_union && !info.reader_is_union); + let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader); + Ok(Self { + plan: UnionReadPlan::ToSingle { + target, + lookup_table, + }, + ..Self::default() }) } + + fn plan_from_resolved(resolved: Option) -> Result { + match resolved { + None => Ok(UnionReadPlan::Passthrough), + Some(info) => match (info.writer_is_union, info.reader_is_union) { + (true, true) => { + let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader); + Ok(UnionReadPlan::ReaderUnion { lookup_table }) + } + (false, true) => { + let (reader_idx, promotion) = + info.writer_to_reader.first().and_then(|x| *x).ok_or_else(|| { + ArrowError::SchemaError( + "Writer type does not match any reader union branch".to_string(), + ) + })?; + Ok(UnionReadPlan::FromSingle { + reader_idx, + promotion, + }) + } + (true, false) => Err(ArrowError::InvalidArgumentError( + "UnionDecoder::try_new cannot build 
writer-union to single; use UnionDecoderBuilder with a target" + .to_string(), + )), + (false, false) => Ok(UnionReadPlan::Passthrough), + }, + } + } + + #[inline] + fn read_tag(buf: &mut AvroCursor<'_>) -> Result { + let tag = buf.get_long()?; + if tag < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {tag}" + ))); + } + Ok(tag as usize) + } + + #[inline] + fn emit_to(&mut self, reader_idx: usize) -> Result<&mut Decoder, ArrowError> { + if reader_idx >= self.branches.len() { + return Err(ArrowError::ParseError(format!( + "Union branch index {reader_idx} out of range ({} branches)", + self.branches.len() + ))); + } + self.type_ids.push(self.type_id_by_reader_idx[reader_idx]); + self.offsets.push(self.counts[reader_idx]); + self.counts[reader_idx] += 1; + Ok(&mut self.branches[reader_idx]) + } + + #[inline] + fn on_decoder(&mut self, fallback_idx: usize, action: F) -> Result<(), ArrowError> + where + F: FnOnce(&mut Decoder) -> Result<(), ArrowError>, + { + if let UnionReadPlan::ToSingle { target, .. } = &mut self.plan { + return action(target); + } + let reader_idx = match &self.plan { + UnionReadPlan::FromSingle { reader_idx, .. 
} => *reader_idx, + _ => fallback_idx, + }; + self.emit_to(reader_idx).and_then(action) + } + + fn append_null(&mut self) -> Result<(), ArrowError> { + self.on_decoder(self.null_emit_idx, |decoder| decoder.append_null()) + } + + fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> { + self.on_decoder(self.default_emit_idx, |decoder| decoder.append_default(lit)) + } + + fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + let (reader_idx, promotion) = match &mut self.plan { + UnionReadPlan::ToSingle { + target, + lookup_table, + } => { + let idx = Self::read_tag(buf)?; + return match lookup_table.resolve(idx) { + Some((_, promotion)) => target.decode_with_promotion(buf, promotion), + None => Err(ArrowError::ParseError(format!( + "Writer union branch {idx} does not resolve to reader type" + ))), + }; + } + UnionReadPlan::Passthrough => (Self::read_tag(buf)?, Promotion::Direct), + UnionReadPlan::ReaderUnion { lookup_table } => { + let idx = Self::read_tag(buf)?; + lookup_table.resolve(idx).ok_or_else(|| { + ArrowError::ParseError(format!( + "Union branch index {idx} not resolvable by reader schema" + )) + })? + } + UnionReadPlan::FromSingle { + reader_idx, + promotion, + } => (*reader_idx, *promotion), + UnionReadPlan::ToSingle { .. } => { + return Err(ArrowError::ParseError( + "Invalid union read plan state".to_string(), + )); + } + }; + let decoder = self.emit_to(reader_idx)?; + decoder.decode_with_promotion(buf, promotion) + } + + fn flush(&mut self, nulls: Option) -> Result { + match &mut self.plan { + UnionReadPlan::ToSingle { target, .. 
} => target.flush(nulls), + _ => { + debug_assert!( + nulls.is_none(), + "UnionArray does not accept a validity bitmap; \ + nulls should have been materialized as a Null child during decode" + ); + let children = self + .branches + .iter_mut() + .map(|d| d.flush(None)) + .collect::, _>>()?; + let type_ids_buf: ScalarBuffer = + flush_values(&mut self.type_ids).into_iter().collect(); + let offsets_buf: ScalarBuffer = + flush_values(&mut self.offsets).into_iter().collect(); + let arr = UnionArray::try_new( + self.fields.clone(), + type_ids_buf, + Some(offsets_buf), + children, + ) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(arr)) + } + } + } +} + +#[derive(Debug, Default)] +struct UnionDecoderBuilder { + fields: Option, + branches: Option>, + resolved: Option, + target: Option>, +} + +impl UnionDecoderBuilder { + fn new() -> Self { + Self::default() + } + + fn with_fields(mut self, fields: UnionFields) -> Self { + self.fields = Some(fields); + self + } + + fn with_branches(mut self, branches: Vec) -> Self { + self.branches = Some(branches); + self + } + + fn with_resolved_union(mut self, resolved_union: ResolvedUnion) -> Self { + self.resolved = Some(resolved_union); + self + } + + fn with_target(mut self, target: Box) -> Self { + self.target = Some(target); + self + } + + fn build(self) -> Result { + match (self.resolved, self.fields, self.branches, self.target) { + (resolved, Some(fields), Some(branches), _) => { + UnionDecoder::try_new(fields, branches, resolved) + } + (Some(info), None, None, Some(target)) + if info.writer_is_union && !info.reader_is_union => + { + UnionDecoder::try_new_from_writer_union(info, target) + } + _ => Err(ArrowError::InvalidArgumentError( + "UnionDecoderBuilder requires (fields + branches) for reader-unions \ + or (resolved + target) for writer-union to single" + .to_string(), + )), + } + } } #[derive(Debug, Copy, Clone)] @@ -1515,20 +1473,6 @@ fn flush_dict( .map(|arr| Arc::new(arr) as ArrayRef) } -#[inline] 
-fn type_id_at(fields: &UnionFields, idx: usize, encodings_len: usize) -> Result { - fields - .iter() - .nth(idx) - .map(|(tid, _)| tid) - .ok_or_else(|| { - ArrowError::SchemaError(format!( - "UnionFields/encodings length mismatch: child_idx={idx}, fields_len={}, encodings_len={encodings_len}", - fields.iter().count() - )) - }) -} - #[inline] fn read_blocks( buf: &mut AvroCursor, @@ -1966,11 +1910,7 @@ mod tests { use super::*; use crate::codec::AvroField; use crate::schema::{PrimitiveType, Schema, TypeName}; - use arrow_array::{ - cast::AsArray, Array, Decimal128Array, Decimal256Array, Decimal32Array, DictionaryArray, - FixedSizeBinaryArray, IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, - StructArray, - }; + use arrow_array::cast::AsArray; use indexmap::IndexMap; fn encode_avro_int(value: i32) -> Vec { @@ -3232,12 +3172,11 @@ mod tests { } fn make_dense_union_avro( - children: Vec<(Codec, &'static str, DataType)>, + children: Vec<(Codec, &'_ str, DataType)>, type_ids: Vec, ) -> AvroDataType { let mut avro_children: Vec = Vec::with_capacity(children.len()); let mut fields: Vec = Vec::with_capacity(children.len()); - for (codec, name, dt) in children.into_iter() { avro_children.push(AvroDataType::new(codec, Default::default(), None)); fields.push(arrow_schema::Field::new(name, dt, true)); diff --git a/arrow-avro/test/data/README.md b/arrow-avro/test/data/README.md index 8ddaaff3324a..1d7d8482f924 100644 --- a/arrow-avro/test/data/README.md +++ b/arrow-avro/test/data/README.md @@ -20,12 +20,12 @@ # Avro test files for `arrow-avro` This directory contains small Avro Object Container Files (OCF) used by -`arrow-avro` tests to validate the `Reader` implementation. These files are generated from +`arrow-avro` tests to validate the `Reader` implementation. These files are generated from a set of python scripts and will gradually be removed as they are merged into `arrow-testing`. 
## Decimal Files -This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type +This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type across both `bytes` and `fixed` encodings, and to cover Arrow decimal widths ranging from `Decimal32` up through `Decimal256`. The files were generated from a script (see **How these files were created** below). From 63ffed27e6f4dd94b26a756983fee46700ca567d Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 10:47:02 -0500 Subject: [PATCH 12/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index e5e619ca1075..19d4ded0fac2 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -528,7 +528,7 @@ impl Decoder { Self::Uuid(v) => { v.extend([0; 16]); } - Self::Array(_, offsets, e) => { + Self::Array(_, offsets, _) => { offsets.push_length(0); } Self::Record(_, e, _) => { From 34e258f6428359937f51f18cfd6e71ccadbd98f8 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 10:54:08 -0500 Subject: [PATCH 13/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 19d4ded0fac2..aab00c5bbbff 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -1371,10 +1371,8 @@ impl UnionDecoder { .iter_mut() .map(|d| d.flush(None)) .collect::, _>>()?; - let type_ids_buf: ScalarBuffer = - flush_values(&mut self.type_ids).into_iter().collect(); - let offsets_buf: ScalarBuffer = - flush_values(&mut self.offsets).into_iter().collect(); + let type_ids_buf = flush_values(&mut self.type_ids).into_iter().collect(); + let 
offsets_buf = flush_values(&mut self.offsets).into_iter().collect(); let arr = UnionArray::try_new( self.fields.clone(), type_ids_buf, From ddde842672c7a6a7e035306868284f866e9c2689 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 12:48:05 -0500 Subject: [PATCH 14/22] Address PR Comments --- arrow-avro/src/reader/record.rs | 243 +++++++++++++++----------------- 1 file changed, 111 insertions(+), 132 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index aab00c5bbbff..4a2fed5faba5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -35,6 +35,7 @@ use arrow_schema::{ use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use std::cmp::Ordering; use std::sync::Arc; +use strum_macros::AsRefStr; use uuid::Uuid; const DEFAULT_CAPACITY: usize = 1024; @@ -81,22 +82,6 @@ macro_rules! append_decimal_default { }}; } -macro_rules! promote_numeric { - ($self:expr, $buf:expr, $variant:ident, $getter:ident, $to:ty, $promotion:expr) => {{ - match $self { - Decoder::$variant(v) => { - let x = $buf.$getter()?; - v.push(x as $to); - Ok(()) - } - _ => Err(ArrowError::ParseError(format!( - "Promotion {} target mismatch", - $promotion - ))), - } - }}; -} - #[derive(Debug)] pub(crate) struct RecordDecoderBuilder<'a> { data_type: &'a AvroDataType, @@ -228,7 +213,7 @@ struct EnumResolution { default_index: i32, } -#[derive(Debug)] +#[derive(Debug, AsRefStr)] enum Decoder { Null(usize), Boolean(BooleanBufferBuilder), @@ -458,12 +443,7 @@ impl Decoder { Box::new(val_dec), ) } - (Codec::Union(encodings, fields, mode), _) => { - if *mode != UnionMode::Dense { - return Err(ArrowError::NotYetImplemented( - "Sparse Arrow unions are not yet supported".to_string(), - )); - } + (Codec::Union(encodings, fields, mode), _) if *mode == UnionMode::Dense => { let decoders = encodings .iter() .map(Self::try_new_internal) @@ -485,12 +465,12 @@ impl Decoder { } Self::Union(builder.build()?) 
} - (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), - (&Codec::Union(_, _, _), _) => { + (Codec::Union(_, _, _), _) => { return Err(ArrowError::NotYetImplemented( - "Union type decoding is not yet supported".to_string(), - )) + "Sparse Arrow unions are not yet supported".to_string(), + )); } + (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), }; Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( @@ -924,20 +904,30 @@ impl Decoder { buf: &mut AvroCursor<'_>, promotion: Promotion, ) -> Result<(), ArrowError> { + macro_rules! promote_numeric_to { + ($variant:ident, $getter:ident, $to:ty) => {{ + match self { + Self::$variant(v) => { + let x = buf.$getter()?; + v.push(x as $to); + Ok(()) + } + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected {}, got {}", + stringify!($variant), + >::as_ref(other) + ))), + } + }}; + } match promotion { Promotion::Direct => self.decode(buf), - Promotion::IntToLong => promote_numeric!(self, buf, Int64, get_int, i64, promotion), - Promotion::IntToFloat => promote_numeric!(self, buf, Float32, get_int, f32, promotion), - Promotion::IntToDouble => promote_numeric!(self, buf, Float64, get_int, f64, promotion), - Promotion::LongToFloat => { - promote_numeric!(self, buf, Float32, get_long, f32, promotion) - } - Promotion::LongToDouble => { - promote_numeric!(self, buf, Float64, get_long, f64, promotion) - } - Promotion::FloatToDouble => { - promote_numeric!(self, buf, Float64, get_float, f64, promotion) - } + Promotion::IntToLong => promote_numeric_to!(Int64, get_int, i64), + Promotion::IntToFloat => promote_numeric_to!(Float32, get_int, f32), + Promotion::IntToDouble => promote_numeric_to!(Float64, get_int, f64), + Promotion::LongToFloat => promote_numeric_to!(Float32, get_long, f32), + Promotion::LongToDouble => promote_numeric_to!(Float64, get_long, f64), + Promotion::FloatToDouble => promote_numeric_to!(Float64, get_float, f64), 
Promotion::StringToBytes => match self { Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => { let data = buf.get_bytes()?; @@ -945,8 +935,9 @@ impl Decoder { values.extend_from_slice(data); Ok(()) } - _ => Err(ArrowError::ParseError(format!( - "Promotion {promotion} target mismatch", + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected bytes (Binary/StringToBytes), got {}", + >::as_ref(other) ))), }, Promotion::BytesToString => match self { @@ -958,8 +949,9 @@ impl Decoder { values.extend_from_slice(data); Ok(()) } - _ => Err(ArrowError::ParseError(format!( - "Promotion {promotion} target mismatch", + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected string (String/StringView/BytesToString), got {}", + >::as_ref(other) ))), }, } @@ -1110,12 +1102,12 @@ impl Decoder { } #[derive(Debug)] -struct DispatchLut { +struct DispatchLookupTable { to_reader: Box<[i16]>, promotion: Box<[Promotion]>, } -impl DispatchLut { +impl DispatchLookupTable { fn from_writer_to_reader(promotion_map: &[Option<(usize, Promotion)>]) -> Self { let mut to_reader = Vec::with_capacity(promotion_map.len()); let mut promotion = Vec::with_capacity(promotion_map.len()); @@ -1140,16 +1132,9 @@ impl DispatchLut { // Resolve a writer branch index to (reader_idx, promotion) #[inline] - fn resolve(&self, writer_idx: usize) -> Option<(usize, Promotion)> { - if writer_idx >= self.to_reader.len() { - return None; - } - let reader_index = self.to_reader[writer_idx]; - if reader_index < 0 { - None - } else { - Some((reader_index as usize, self.promotion[writer_idx])) - } + fn resolve(&self, writer_index: usize) -> Option<(usize, Promotion)> { + let reader_index = *self.to_reader.get(writer_index)?; + (reader_index >= 0).then(|| (reader_index as usize, self.promotion[writer_index])) } } @@ -1160,7 +1145,7 @@ struct UnionDecoder { offsets: Vec, branches: Vec, counts: Vec, - type_id_by_reader_idx: 
Arc<[i8]>, + type_id_by_reader_idx: Vec, null_branch: Option, default_emit_idx: usize, null_emit_idx: usize, @@ -1175,7 +1160,7 @@ impl Default for UnionDecoder { offsets: Vec::new(), branches: Vec::new(), counts: Vec::new(), - type_id_by_reader_idx: Arc::from([]), + type_id_by_reader_idx: Vec::new(), null_branch: None, default_emit_idx: 0, null_emit_idx: 0, @@ -1187,7 +1172,7 @@ impl Default for UnionDecoder { #[derive(Debug)] enum UnionReadPlan { ReaderUnion { - lookup_table: DispatchLut, + lookup_table: DispatchLookupTable, }, FromSingle { reader_idx: usize, @@ -1195,7 +1180,7 @@ enum UnionReadPlan { }, ToSingle { target: Box, - lookup_table: DispatchLut, + lookup_table: DispatchLookupTable, }, Passthrough, } @@ -1206,12 +1191,10 @@ impl UnionDecoder { branches: Vec, resolved: Option, ) -> Result { - let reader_type_codes: Arc<[i8]> = - Arc::from(fields.iter().map(|(tid, _)| tid).collect::>()); + let reader_type_codes = fields.iter().map(|(tid, _)| tid).collect::>(); let null_branch = branches.iter().position(|b| matches!(b, Decoder::Null(_))); let default_emit_idx = 0; let null_emit_idx = null_branch.unwrap_or(default_emit_idx); - let plan = Self::plan_from_resolved(resolved)?; let branch_len = branches.len().max(reader_type_codes.len()); Ok(Self { fields, @@ -1223,7 +1206,7 @@ impl UnionDecoder { null_branch, default_emit_idx, null_emit_idx, - plan, + plan: Self::plan_from_resolved(resolved)?, }) } @@ -1233,7 +1216,7 @@ impl UnionDecoder { ) -> Result { // This constructor is only for writer-union to single-type resolution debug_assert!(info.writer_is_union && !info.reader_is_union); - let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader); + let lookup_table = DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader); Ok(Self { plan: UnionReadPlan::ToSingle { target, @@ -1244,31 +1227,33 @@ impl UnionDecoder { } fn plan_from_resolved(resolved: Option) -> Result { - match resolved { - None => Ok(UnionReadPlan::Passthrough), - 
Some(info) => match (info.writer_is_union, info.reader_is_union) { - (true, true) => { - let lookup_table = DispatchLut::from_writer_to_reader(&info.writer_to_reader); - Ok(UnionReadPlan::ReaderUnion { lookup_table }) - } - (false, true) => { - let (reader_idx, promotion) = - info.writer_to_reader.first().and_then(|x| *x).ok_or_else(|| { - ArrowError::SchemaError( - "Writer type does not match any reader union branch".to_string(), - ) - })?; - Ok(UnionReadPlan::FromSingle { - reader_idx, - promotion, - }) - } - (true, false) => Err(ArrowError::InvalidArgumentError( - "UnionDecoder::try_new cannot build writer-union to single; use UnionDecoderBuilder with a target" - .to_string(), - )), - (false, false) => Ok(UnionReadPlan::Passthrough), - }, + let Some(info) = resolved else { + return Ok(UnionReadPlan::Passthrough); + }; + match (info.writer_is_union, info.reader_is_union) { + (true, true) => { + let lookup_table = + DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader); + Ok(UnionReadPlan::ReaderUnion { lookup_table }) + } + (false, true) => { + let Some(&(reader_idx, promotion)) = + info.writer_to_reader.first().and_then(Option::as_ref) + else { + return Err(ArrowError::SchemaError( + "Writer type does not match any reader union branch".to_string(), + )); + }; + Ok(UnionReadPlan::FromSingle { + reader_idx, + promotion, + }) + } + (true, false) => Err(ArrowError::InvalidArgumentError( + "UnionDecoder::try_new cannot build writer-union to single; use UnionDecoderBuilder with a target" + .to_string(), + )), + (false, false) => Ok(UnionReadPlan::Passthrough), } } @@ -1322,18 +1307,6 @@ impl UnionDecoder { fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { let (reader_idx, promotion) = match &mut self.plan { - UnionReadPlan::ToSingle { - target, - lookup_table, - } => { - let idx = Self::read_tag(buf)?; - return match lookup_table.resolve(idx) { - Some((_, promotion)) => target.decode_with_promotion(buf, promotion), - None => 
Err(ArrowError::ParseError(format!( - "Writer union branch {idx} does not resolve to reader type" - ))), - }; - } UnionReadPlan::Passthrough => (Self::read_tag(buf)?, Promotion::Direct), UnionReadPlan::ReaderUnion { lookup_table } => { let idx = Self::read_tag(buf)?; @@ -1347,10 +1320,17 @@ impl UnionDecoder { reader_idx, promotion, } => (*reader_idx, *promotion), - UnionReadPlan::ToSingle { .. } => { - return Err(ArrowError::ParseError( - "Invalid union read plan state".to_string(), - )); + UnionReadPlan::ToSingle { + target, + lookup_table, + } => { + let idx = Self::read_tag(buf)?; + return match lookup_table.resolve(idx) { + Some((_, promotion)) => target.decode_with_promotion(buf, promotion), + None => Err(ArrowError::ParseError(format!( + "Writer union branch {idx} does not resolve to reader type" + ))), + }; } }; let decoder = self.emit_to(reader_idx)?; @@ -1358,31 +1338,29 @@ impl UnionDecoder { } fn flush(&mut self, nulls: Option) -> Result { - match &mut self.plan { - UnionReadPlan::ToSingle { target, .. } => target.flush(nulls), - _ => { - debug_assert!( - nulls.is_none(), - "UnionArray does not accept a validity bitmap; \ - nulls should have been materialized as a Null child during decode" - ); - let children = self - .branches - .iter_mut() - .map(|d| d.flush(None)) - .collect::, _>>()?; - let type_ids_buf = flush_values(&mut self.type_ids).into_iter().collect(); - let offsets_buf = flush_values(&mut self.offsets).into_iter().collect(); - let arr = UnionArray::try_new( - self.fields.clone(), - type_ids_buf, - Some(offsets_buf), - children, - ) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Ok(Arc::new(arr)) - } + if let UnionReadPlan::ToSingle { target, .. 
} = &mut self.plan { + return target.flush(nulls); } + debug_assert!( + nulls.is_none(), + "UnionArray does not accept a validity bitmap; \ + nulls should have been materialized as a Null child during decode" + ); + let children = self + .branches + .iter_mut() + .map(|d| d.flush(None)) + .collect::, _>>()?; + let type_ids_buf = flush_values(&mut self.type_ids).into_iter().collect(); + let offsets_buf = flush_values(&mut self.offsets).into_iter().collect(); + let arr = UnionArray::try_new( + self.fields.clone(), + type_ids_buf, + Some(offsets_buf), + children, + ) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(arr)) } } @@ -1421,7 +1399,7 @@ impl UnionDecoderBuilder { fn build(self) -> Result { match (self.resolved, self.fields, self.branches, self.target) { - (resolved, Some(fields), Some(branches), _) => { + (resolved, Some(fields), Some(branches), None) => { UnionDecoder::try_new(fields, branches, resolved) } (Some(info), None, None, Some(target)) @@ -1430,8 +1408,9 @@ impl UnionDecoderBuilder { UnionDecoder::try_new_from_writer_union(info, target) } _ => Err(ArrowError::InvalidArgumentError( - "UnionDecoderBuilder requires (fields + branches) for reader-unions \ - or (resolved + target) for writer-union to single" + "Invalid UnionDecoderBuilder configuration: expected either \ + (fields + branches + resolved) with no target for reader-unions, or \ + (resolved + target) with no fields/branches for writer-union to single." .to_string(), )), } From 2fe36aa17fdb7d5c2df91e5c9033e77bb7899f8b Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 17:41:23 -0500 Subject: [PATCH 15/22] Changed `DispatchLookupTable` type ids to be `i8` instead of `i16`. 
--- arrow-avro/src/reader/record.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 4a2fed5faba5..853a28adc84d 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -1103,7 +1103,7 @@ impl Decoder { #[derive(Debug)] struct DispatchLookupTable { - to_reader: Box<[i16]>, + to_reader: Box<[i8]>, promotion: Box<[Promotion]>, } @@ -1114,8 +1114,8 @@ impl DispatchLookupTable { for map in promotion_map { match *map { Some((idx, promo)) => { - debug_assert!(idx <= i16::MAX as usize); - to_reader.push(idx as i16); + debug_assert!(idx <= i8::MAX as usize); + to_reader.push(idx as i8); promotion.push(promo); } None => { From f212a31dc35306e659d0665db9c2a68d2c182717 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 18:56:25 -0500 Subject: [PATCH 16/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 853a28adc84d..2e1cb0aa14c5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -443,7 +443,7 @@ impl Decoder { Box::new(val_dec), ) } - (Codec::Union(encodings, fields, mode), _) if *mode == UnionMode::Dense => { + (Codec::Union(encodings, fields, UnionMode::Dense), _) => { let decoders = encodings .iter() .map(Self::try_new_internal) From ebc5a1cba2bc76246cb6a8eb91d9af7093a5f15e Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 18:57:19 -0500 Subject: [PATCH 17/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 2e1cb0aa14c5..203bd523aca6 100644 --- a/arrow-avro/src/reader/record.rs +++ 
b/arrow-avro/src/reader/record.rs @@ -1270,16 +1270,16 @@ impl UnionDecoder { #[inline] fn emit_to(&mut self, reader_idx: usize) -> Result<&mut Decoder, ArrowError> { - if reader_idx >= self.branches.len() { + let Some(reader_branch) = self.branches.get_mut(reader_idx) else { return Err(ArrowError::ParseError(format!( "Union branch index {reader_idx} out of range ({} branches)", self.branches.len() ))); - } + }; self.type_ids.push(self.type_id_by_reader_idx[reader_idx]); self.offsets.push(self.counts[reader_idx]); self.counts[reader_idx] += 1; - Ok(&mut self.branches[reader_idx]) + Ok(reader_branch) } #[inline] From 761adba2f9e0c41e54b1ce7ad0314c6b4d9402d0 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 18:57:45 -0500 Subject: [PATCH 18/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Ryan Johnson --- arrow-avro/src/reader/record.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 203bd523aca6..f99804b7ab6a 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -264,7 +264,6 @@ enum Decoder { impl Decoder { fn try_new(data_type: &AvroDataType) -> Result { - // Extract just the Promotion (if any) to simplify pattern matching if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { if info.writer_is_union && !info.reader_is_union { let mut clone = data_type.clone(); @@ -283,6 +282,7 @@ impl Decoder { } fn try_new_internal(data_type: &AvroDataType) -> Result { + // Extract just the Promotion (if any) to simplify pattern matching let promotion = match data_type.resolution.as_ref() { Some(ResolutionInfo::Promotion(p)) => Some(p), _ => None, @@ -443,6 +443,7 @@ impl Decoder { Box::new(val_dec), ) } + (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), (Codec::Union(encodings, fields, UnionMode::Dense), _) => { let decoders = encodings .iter() @@ -470,7 +471,6 
@@ impl Decoder { "Sparse Arrow unions are not yet supported".to_string(), )); } - (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), }; Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( @@ -1270,10 +1270,10 @@ impl UnionDecoder { #[inline] fn emit_to(&mut self, reader_idx: usize) -> Result<&mut Decoder, ArrowError> { + let branches_len = self.branches.len(); let Some(reader_branch) = self.branches.get_mut(reader_idx) else { return Err(ArrowError::ParseError(format!( - "Union branch index {reader_idx} out of range ({} branches)", - self.branches.len() + "Union branch index {reader_idx} out of range ({branches_len} branches)" ))); }; self.type_ids.push(self.type_id_by_reader_idx[reader_idx]); @@ -1851,21 +1851,19 @@ impl Skipper { } Self::Union(encodings) => { // Union tag must be ZigZag-decoded - let branch = buf.get_long()?; - if branch < 0 { + let idx = buf.get_long()?; + if idx < 0 { return Err(ArrowError::ParseError(format!( - "Negative union branch index {branch}" + "Negative union branch index {idx}" ))); } - let idx = branch as usize; - if let Some(encoding) = encodings.get_mut(idx) { - encoding.skip(buf) - } else { - Err(ArrowError::ParseError(format!( + let Some(encoding) = encodings.get_mut(idx as usize) else { + return Err(ArrowError::ParseError(format!( "Union branch index {idx} out of range for skipper ({} branches)", encodings.len() - ))) - } + ))); + }; + encoding.skip(buf) } Self::Nullable(order, inner) => { let branch = buf.read_vlq()?; From 20d946fefa234a166776930dadc7eeefc93a76d3 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 22 Sep 2025 20:35:55 -0500 Subject: [PATCH 19/22] Address PR Comments --- arrow-avro/src/reader/record.rs | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index f99804b7ab6a..d965fafe72b7 100644 --- a/arrow-avro/src/reader/record.rs +++ 
b/arrow-avro/src/reader/record.rs @@ -945,6 +945,11 @@ impl Decoder { | Self::StringView(offsets, values) | Self::BytesToString(offsets, values) => { let data = buf.get_bytes()?; + std::str::from_utf8(data) + .map(|_| ()) + .map_err(|e| ArrowError::ParseError(format!( + "bytes->string promotion: invalid UTF-8 ({e})" + ))); offsets.push_length(data.len()); values.extend_from_slice(data); Ok(()) @@ -1055,7 +1060,7 @@ impl Decoder { other => { return Err(ArrowError::InvalidArgumentError(format!( "Map entries field must be a Struct, got {other:?}" - ))) + ))); } }; let entries_struct = @@ -1145,7 +1150,7 @@ struct UnionDecoder { offsets: Vec, branches: Vec, counts: Vec, - type_id_by_reader_idx: Vec, + reader_type_codes: Vec, null_branch: Option, default_emit_idx: usize, null_emit_idx: usize, @@ -1160,7 +1165,7 @@ impl Default for UnionDecoder { offsets: Vec::new(), branches: Vec::new(), counts: Vec::new(), - type_id_by_reader_idx: Vec::new(), + reader_type_codes: Vec::new(), null_branch: None, default_emit_idx: 0, null_emit_idx: 0, @@ -1202,7 +1207,7 @@ impl UnionDecoder { offsets: Vec::with_capacity(DEFAULT_CAPACITY), branches, counts: vec![0; branch_len], - type_id_by_reader_idx: reader_type_codes, + reader_type_codes, null_branch, default_emit_idx, null_emit_idx, @@ -1253,7 +1258,11 @@ impl UnionDecoder { "UnionDecoder::try_new cannot build writer-union to single; use UnionDecoderBuilder with a target" .to_string(), )), - (false, false) => Ok(UnionReadPlan::Passthrough), + // (false, false) is invalid and should never be constructed by the resolver. 
+ _ => Err(ArrowError::SchemaError( + "ResolvedUnion constructed for non-union sides; resolver should return None" + .to_string(), + )), } } @@ -1276,7 +1285,7 @@ impl UnionDecoder { "Union branch index {reader_idx} out of range ({branches_len} branches)" ))); }; - self.type_ids.push(self.type_id_by_reader_idx[reader_idx]); + self.type_ids.push(self.reader_type_codes[reader_idx]); self.offsets.push(self.counts[reader_idx]); self.counts[reader_idx] += 1; Ok(reader_branch) @@ -1351,12 +1360,10 @@ impl UnionDecoder { .iter_mut() .map(|d| d.flush(None)) .collect::, _>>()?; - let type_ids_buf = flush_values(&mut self.type_ids).into_iter().collect(); - let offsets_buf = flush_values(&mut self.offsets).into_iter().collect(); let arr = UnionArray::try_new( self.fields.clone(), - type_ids_buf, - Some(offsets_buf), + flush_values(&mut self.type_ids).into_iter().collect(), + Some(flush_values(&mut self.offsets).into_iter().collect()), children, ) .map_err(|e| ArrowError::ParseError(e.to_string()))?; From 99971eb212c5f3faa4f4068acf7ed46e5031248b Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 23 Sep 2025 14:02:42 -0500 Subject: [PATCH 20/22] Update arrow-avro/src/reader/record.rs Co-authored-by: Andrew Lamb --- arrow-avro/src/reader/record.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index d965fafe72b7..1fb50eefcb23 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -1119,7 +1119,7 @@ impl DispatchLookupTable { for map in promotion_map { match *map { Some((idx, promo)) => { - debug_assert!(idx <= i8::MAX as usize); + let idx: i8 = idx.try_into().map_err(|e| ...)?; to_reader.push(idx as i8); promotion.push(promo); } From 26e8e96d4f331b7ec1202147a55d429b0869b4c9 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 23 Sep 2025 14:02:56 -0500 Subject: [PATCH 21/22] Update arrow-avro/src/reader/mod.rs Co-authored-by: Andrew Lamb --- 
arrow-avro/src/reader/mod.rs | 7 ++----- arrow-avro/src/reader/record.rs | 21 ++++++++++++++------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 23647d20a76c..68f30a709a71 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -1289,6 +1289,7 @@ mod test { ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, ListBuilder, MapBuilder, StringBuilder, StructBuilder, }; + use arrow_array::cast::AsArray; use arrow_array::types::{Int32Type, IntervalMonthDayNanoType}; use arrow_array::*; use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; @@ -2742,11 +2743,7 @@ mod test { let batch = read_file(path, 1024, false); let schema = batch.schema(); let idx = schema.index_of("nullable_int_nullfirst").unwrap(); - let a = batch - .column(idx) - .as_any() - .downcast_ref::() - .expect("nullable_int_nullfirst should be Int32"); + let a = batch.column(idx).as_primitive::(); assert_eq!(a.len(), 4); assert!(a.is_null(0)); assert_eq!(a.value(1), 42); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 1fb50eefcb23..29ab8a853185 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -1113,14 +1113,21 @@ struct DispatchLookupTable { } impl DispatchLookupTable { - fn from_writer_to_reader(promotion_map: &[Option<(usize, Promotion)>]) -> Self { + fn from_writer_to_reader( + promotion_map: &[Option<(usize, Promotion)>], + ) -> Result { let mut to_reader = Vec::with_capacity(promotion_map.len()); let mut promotion = Vec::with_capacity(promotion_map.len()); for map in promotion_map { match *map { Some((idx, promo)) => { - let idx: i8 = idx.try_into().map_err(|e| ...)?; - to_reader.push(idx as i8); + let idx_i8 = i8::try_from(idx).map_err(|_| { + ArrowError::SchemaError(format!( + "Reader branch index {idx} exceeds i8 range (max {})", + i8::MAX + )) + })?; + 
to_reader.push(idx_i8); promotion.push(promo); } None => { @@ -1129,10 +1136,10 @@ impl DispatchLookupTable { } } } - Self { + Ok(Self { to_reader: to_reader.into_boxed_slice(), promotion: promotion.into_boxed_slice(), - } + }) } // Resolve a writer branch index to (reader_idx, promotion) @@ -1221,7 +1228,7 @@ impl UnionDecoder { ) -> Result { // This constructor is only for writer-union to single-type resolution debug_assert!(info.writer_is_union && !info.reader_is_union); - let lookup_table = DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader); + let lookup_table = DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader)?; Ok(Self { plan: UnionReadPlan::ToSingle { target, @@ -1238,7 +1245,7 @@ impl UnionDecoder { match (info.writer_is_union, info.reader_is_union) { (true, true) => { let lookup_table = - DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader); + DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader)?; Ok(UnionReadPlan::ReaderUnion { lookup_table }) } (false, true) => { From f3c97a004e14b5a493d6185657897480295e273a Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 23 Sep 2025 16:25:49 -0500 Subject: [PATCH 22/22] Address PR Comments --- arrow-avro/src/reader/mod.rs | 847 +++++++++++++++++++++++++++++++- arrow-avro/src/reader/record.rs | 106 +++- 2 files changed, 932 insertions(+), 21 deletions(-) diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 68f30a709a71..c9e4b1d22914 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -19,6 +19,19 @@ //! //! Facilities to read Apache Avro–encoded data into Arrow's `RecordBatch` format. //! +//! ### Limitations +//! +//!- **Avro unions with > 127 branches are not supported.** +//! When decoding Avro unions to Arrow `UnionArray`, Arrow stores the union +//! type identifiers in an **8‑bit signed** buffer (`i8`). This implies a +//! practical limit of **127** distinct branch ids. 
Inputs that resolve to +//! more than 127 branches will return an error. If you truly need more, +//! model the schema as a **union of unions**, per the Arrow format spec. +//! +//! See: Arrow Columnar Format — Dense Union (“types buffer: 8‑bit signed; +//! a union with more than 127 possible types can be modeled as a union of +//! unions”). +//! //! This module exposes three layers of the API surface, from highest to lowest-level: //! //! * [`ReaderBuilder`](crate::reader::ReaderBuilder): configures how Avro is read (batch size, strict union handling, @@ -1292,7 +1305,9 @@ mod test { use arrow_array::cast::AsArray; use arrow_array::types::{Int32Type, IntervalMonthDayNanoType}; use arrow_array::*; - use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_buffer::{ + i256, Buffer, IntervalMonthDayNano, NullBuffer, OffsetBuffer, ScalarBuffer, + }; use arrow_schema::{ ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema, UnionFields, UnionMode, }; @@ -3490,6 +3505,836 @@ mod test { } } + #[test] + fn test_union_fields_end_to_end_expected_arrays() { + fn tid_by_name(fields: &UnionFields, want: &str) -> i8 { + for (tid, f) in fields.iter() { + if f.name() == want { + return tid; + } + } + panic!("union child '{want}' not found") + } + + fn tid_by_dt(fields: &UnionFields, pred: impl Fn(&DataType) -> bool) -> i8 { + for (tid, f) in fields.iter() { + if pred(f.data_type()) { + return tid; + } + } + panic!("no union child matches predicate") + } + + fn uuid16_from_str(s: &str) -> [u8; 16] { + fn hex(b: u8) -> u8 { + match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => b - b'a' + 10, + b'A'..=b'F' => b - b'A' + 10, + _ => panic!("invalid hex"), + } + } + let mut out = [0u8; 16]; + let bytes = s.as_bytes(); + let (mut i, mut j) = (0, 0); + while i < bytes.len() { + if bytes[i] == b'-' { + i += 1; + continue; + } + let hi = hex(bytes[i]); + let lo = hex(bytes[i + 1]); + out[j] = (hi << 4) | lo; + j += 1; + i += 2; + } + 
assert_eq!(j, 16, "uuid must decode to 16 bytes"); + out + } + + fn empty_child_for(dt: &DataType) -> Arc { + match dt { + DataType::Null => Arc::new(NullArray::new(0)), + DataType::Boolean => Arc::new(BooleanArray::from(Vec::::new())), + DataType::Int32 => Arc::new(Int32Array::from(Vec::::new())), + DataType::Int64 => Arc::new(Int64Array::from(Vec::::new())), + DataType::Float32 => Arc::new(arrow_array::Float32Array::from(Vec::::new())), + DataType::Float64 => Arc::new(arrow_array::Float64Array::from(Vec::::new())), + DataType::Binary => Arc::new(BinaryArray::from(Vec::<&[u8]>::new())), + DataType::Utf8 => Arc::new(StringArray::from(Vec::<&str>::new())), + DataType::Date32 => Arc::new(arrow_array::Date32Array::from(Vec::::new())), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Arc::new(Time32MillisecondArray::from(Vec::::new())) + } + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Arc::new(Time64MicrosecondArray::from(Vec::::new())) + } + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => { + let a = TimestampMicrosecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Arc::new(arrow_array::IntervalMonthDayNanoArray::from(Vec::< + IntervalMonthDayNano, + >::new( + ))) + } + DataType::FixedSizeBinary(n) => Arc::new(FixedSizeBinaryArray::new_null(*n, 0)), + DataType::Dictionary(k, v) => { + assert_eq!(**k, DataType::Int32, "expect int32 keys for enums"); + let keys = Int32Array::from(Vec::::new()); + let values = match v.as_ref() { + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + other => panic!("unexpected dictionary value type {other:?}"), + }; + 
Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + } + DataType::List(field) => { + let values: ArrayRef = match field.data_type() { + DataType::Int32 => { + Arc::new(Int32Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Union(_, _) => { + let (uf, _) = if let DataType::Union(f, m) = field.data_type() { + (f.clone(), m) + } else { + unreachable!() + }; + let children: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + children, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported list item type: {other:?}"), + }; + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(ListArray::try_new(field.clone(), offsets, values, None).unwrap()) + } + DataType::Map(entry_field, ordered) => { + let DataType::Struct(childs) = entry_field.data_type() else { + panic!("map entries must be struct") + }; + let key_field = &childs[0]; + let val_field = &childs[1]; + assert_eq!(key_field.data_type(), &DataType::Utf8); + let keys = StringArray::from(Vec::<&str>::new()); + let vals: ArrayRef = match val_field.data_type() { + DataType::Float64 => { + Arc::new(arrow_array::Float64Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Union(uf, _) => { + let ch: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + ch, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported 
map value type: {other:?}"), + }; + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + *ordered, + )) + } + other => panic!("empty_child_for: unhandled type {other:?}"), + } + } + + fn mk_dense_union( + fields: &UnionFields, + type_ids: Vec, + offsets: Vec, + provide: impl Fn(&Field) -> Option, + ) -> ArrayRef { + let children: Vec = fields + .iter() + .map(|(_, f)| provide(f).unwrap_or_else(|| empty_child_for(f.data_type()))) + .collect(); + + Arc::new( + UnionArray::try_new( + fields.clone(), + ScalarBuffer::::from(type_ids), + Some(ScalarBuffer::::from(offsets)), + children, + ) + .unwrap(), + ) as ArrayRef + } + + // Dates / times / timestamps from the Avro content block: + let date_a: i32 = 19_000; + let time_ms_a: i32 = 13 * 3_600_000 + 45 * 60_000 + 30_000 + 123; + let time_us_b: i64 = 23 * 3_600_000_000 + 59 * 60_000_000 + 59 * 1_000_000 + 999_999; + let ts_ms_2024_01_01: i64 = 1_704_067_200_000; + let ts_us_2024_01_01: i64 = ts_ms_2024_01_01 * 1000; + // Fixed / bytes-like values: + let fx8_a: [u8; 8] = *b"ABCDEFGH"; + let fx4_abcd: [u8; 4] = *b"ABCD"; + let fx4_misc: [u8; 4] = [0x00, 0x11, 0x22, 0x33]; + let fx10_ascii: [u8; 10] = *b"0123456789"; + let fx10_aa: [u8; 10] = [0xAA; 10]; + // Duration logical values as MonthDayNano: + let dur_a = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000_000); + let dur_b = IntervalMonthDayNanoType::make_value(12, 31, 999_000_000); + // UUID logical values (stored as 16-byte FixedSizeBinary in Arrow): + let uuid1 = uuid16_from_str("fe7bc30b-4ce8-4c5e-b67c-2234a2d38e66"); + let uuid2 = uuid16_from_str("0826cc06-d2e3-4599-b4ad-af5fa6905cdb"); + // Decimals from Avro content: + let dec_b_scale2_pos: i128 = 123_456; // "1234.56" bytes-decimal -> (precision=10, 
scale=2) + let dec_fix16_neg: i128 = -101; // "-1.01" fixed(16) decimal(10,2) + let dec_fix20_s4: i128 = 1_234_567_891_234; // "123456789.1234" fixed(20) decimal(20,4) + let dec_fix20_s4_neg: i128 = -123; // "-0.0123" fixed(20) decimal(20,4) + let path = "test/data/union_fields.avro"; + let actual = read_file(path, 1024, false); + let schema = actual.schema(); + // Helper to fetch union metadata for a column + let get_union = |name: &str| -> (UnionFields, UnionMode) { + let idx = schema.index_of(name).unwrap(); + match schema.field(idx).data_type() { + DataType::Union(f, m) => (f.clone(), *m), + other => panic!("{name} should be a Union, got {other:?}"), + } + }; + let mut expected_cols: Vec = Vec::with_capacity(schema.fields().len()); + // 1) ["null","int"]: Int32 (nullable) + expected_cols.push(Arc::new(Int32Array::from(vec![ + None, + Some(42), + None, + Some(0), + ]))); + // 2) ["string","null"]: Utf8 (nullable) + expected_cols.push(Arc::new(StringArray::from(vec![ + Some("s1"), + None, + Some("s3"), + Some(""), + ]))); + // 3) union_prim: ["boolean","int","long","float","double","bytes","string"] + { + let (uf, mode) = get_union("union_prim"); + assert!(matches!(mode, UnionMode::Dense)); + let tids = vec![ + tid_by_name(&uf, "long"), + tid_by_name(&uf, "int"), + tid_by_name(&uf, "float"), + tid_by_name(&uf, "double"), + ]; + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "int" => Some(Arc::new(Int32Array::from(vec![-1])) as ArrayRef), + "long" => Some(Arc::new(Int64Array::from(vec![1_234_567_890_123i64])) as ArrayRef), + "float" => { + Some(Arc::new(arrow_array::Float32Array::from(vec![1.25f32])) as ArrayRef) + } + "double" => { + Some(Arc::new(arrow_array::Float64Array::from(vec![-2.5f64])) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 4) union_bytes_vs_string: ["bytes","string"] + { + let (uf, _) = get_union("union_bytes_vs_string"); + let tids = vec![ + tid_by_name(&uf, 
"bytes"), + tid_by_name(&uf, "string"), + tid_by_name(&uf, "string"), + tid_by_name(&uf, "bytes"), + ]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "bytes" => Some( + Arc::new(BinaryArray::from(vec![&[0x00, 0xFF, 0x7F][..], &[][..]])) as ArrayRef, + ), + "string" => Some(Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 5) union_fixed_dur_decfix: [Fx8, Dur12, DecFix16(decimal(10,2))] + { + let (uf, _) = get_union("union_fixed_dur_decfix"); + let tid_fx8 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(8))); + let tid_dur = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) + ) + }); + let tid_dec = tid_by_dt(&uf, |dt| match dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => true, + DataType::Decimal128(10, 2) | DataType::Decimal256(10, 2) => true, + _ => false, + }); + let tids = vec![tid_fx8, tid_dur, tid_dec, tid_dur]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(8) => { + let it = [Some(fx8_a)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 8).unwrap(), + ) as ArrayRef) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Some(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(vec![ + dur_a, dur_b, + ])) as ArrayRef) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => { + let a = arrow_array::Decimal64Array::from_iter_values([dec_fix16_neg as i64]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(10, 2) => { + let a = arrow_array::Decimal128Array::from_iter_values([dec_fix16_neg]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal256(10, 2) => { + let a = 
arrow_array::Decimal256Array::from_iter_values([i256::from_i128( + dec_fix16_neg, + )]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 6) union_enum_records_array_map: [enum ColorU, record RecA, record RecB, array, map] + { + let (uf, _) = get_union("union_enum_records_array_map"); + let tid_enum = tid_by_dt(&uf, |dt| matches!(dt, DataType::Dictionary(_, _))); + let tid_reca = tid_by_dt(&uf, |dt| { + if let DataType::Struct(fs) = dt { + fs.len() == 2 && fs[0].name() == "a" && fs[1].name() == "b" + } else { + false + } + }); + let tid_recb = tid_by_dt(&uf, |dt| { + if let DataType::Struct(fs) = dt { + fs.len() == 2 && fs[0].name() == "x" && fs[1].name() == "y" + } else { + false + } + }); + let tid_arr = tid_by_dt(&uf, |dt| matches!(dt, DataType::List(_))); + let tids = vec![tid_enum, tid_reca, tid_recb, tid_arr]; + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32]); // "RED" + let values = + Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + DataType::Struct(fs) + if fs.len() == 2 && fs[0].name() == "a" && fs[1].name() == "b" => + { + let a = Int32Array::from(vec![7]); + let b = StringArray::from(vec!["x"]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(a), Arc::new(b)], + None, + )) as ArrayRef) + } + DataType::Struct(fs) + if fs.len() == 2 && fs[0].name() == "x" && fs[1].name() == "y" => + { + let x = Int64Array::from(vec![123_456_789i64]); + let y = BinaryArray::from(vec![&[0xFF, 0x00][..]]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(x), Arc::new(y)], + None, + )) as ArrayRef) + } + DataType::List(field) => { + let values = Int64Array::from(vec![1i64, 2, 3]); + let offsets = 
OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3])); + Some(Arc::new( + ListArray::try_new(field.clone(), offsets, Arc::new(values), None).unwrap(), + ) as ArrayRef) + } + DataType::Map(_, _) => None, + other => panic!("unexpected child {other:?}"), + }); + expected_cols.push(arr); + } + // 7) union_date_or_fixed4: [date32, fixed(4)] + { + let (uf, _) = get_union("union_date_or_fixed4"); + let tid_date = tid_by_dt(&uf, |dt| matches!(dt, DataType::Date32)); + let tid_fx4 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(4))); + let tids = vec![tid_date, tid_fx4, tid_date, tid_fx4]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Date32 => { + Some(Arc::new(arrow_array::Date32Array::from(vec![date_a, 0])) as ArrayRef) + } + DataType::FixedSizeBinary(4) => { + let it = [Some(fx4_abcd), Some(fx4_misc)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 4).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 8) union_time_millis_or_enum: [time-millis, enum OnOff] + { + let (uf, _) = get_union("union_time_millis_or_enum"); + let tid_ms = tid_by_dt(&uf, |dt| { + matches!(dt, DataType::Time32(arrow_schema::TimeUnit::Millisecond)) + }); + let tid_en = tid_by_dt(&uf, |dt| matches!(dt, DataType::Dictionary(_, _))); + let tids = vec![tid_ms, tid_en, tid_en, tid_ms]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Some(Arc::new(Time32MillisecondArray::from(vec![time_ms_a, 0])) as ArrayRef) + } + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32, 1]); // "ON", "OFF" + let values = Arc::new(StringArray::from(vec!["ON", "OFF"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + _ => None, + }); + expected_cols.push(arr); + } + // 9) 
union_time_micros_or_string: [time-micros, string] + { + let (uf, _) = get_union("union_time_micros_or_string"); + let tid_us = tid_by_dt(&uf, |dt| { + matches!(dt, DataType::Time64(arrow_schema::TimeUnit::Microsecond)) + }); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_s, tid_us, tid_s, tid_s]; + let offs = vec![0, 0, 1, 2]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Some(Arc::new(Time64MicrosecondArray::from(vec![time_us_b])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["evening", "night", ""])) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 10) union_ts_millis_utc_or_array: [timestamp-millis(TZ), array] + { + let (uf, _) = get_union("union_ts_millis_utc_or_array"); + let tid_ts = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) + ) + }); + let tid_arr = tid_by_dt(&uf, |dt| matches!(dt, DataType::List(_))); + let tids = vec![tid_ts, tid_arr, tid_arr, tid_ts]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01, + ts_ms_2024_01_01 + 86_400_000, + ]); + Some(Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) as ArrayRef) + } + DataType::List(field) => { + let values = Int32Array::from(vec![0, 1, 2, -1, 0, 1]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 6])); + Some(Arc::new( + ListArray::try_new(field.clone(), offsets, Arc::new(values), None).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 11) union_ts_micros_local_or_bytes: [local-timestamp-micros, bytes] + { + let (uf, _) = get_union("union_ts_micros_local_or_bytes"); + let tid_lts = tid_by_dt(&uf, |dt| { + matches!( + dt, + 
DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) + ) + }); + let tid_b = tid_by_name(&uf, "bytes"); + let tids = vec![tid_b, tid_lts, tid_b, tid_b]; + let offs = vec![0, 0, 1, 2]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) => Some(Arc::new( + TimestampMicrosecondArray::from(vec![ts_us_2024_01_01]), + ) + as ArrayRef), + DataType::Binary => Some(Arc::new(BinaryArray::from(vec![ + &b"\x11\x22\x33"[..], + &b"\x00"[..], + &b"\x10\x20\x30\x40"[..], + ])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 12) union_uuid_or_fixed10: [uuid(string)->fixed(16), fixed(10)] + { + let (uf, _) = get_union("union_uuid_or_fixed10"); + let tid_fx16 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(16))); + let tid_fx10 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(10))); + let tids = vec![tid_fx16, tid_fx10, tid_fx16, tid_fx10]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(16) => { + let it = [Some(uuid1), Some(uuid2)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef) + } + DataType::FixedSizeBinary(10) => { + let it = [Some(fx10_ascii), Some(fx10_aa)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 10).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 13) union_dec_bytes_or_dec_fixed: [bytes dec(10,2), fixed(20) dec(20,4)] + { + let (uf, _) = get_union("union_dec_bytes_or_dec_fixed"); + let tid_b10s2 = tid_by_dt(&uf, |dt| match dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => true, + DataType::Decimal128(10, 2) | DataType::Decimal256(10, 2) => true, + _ => false, + }); + let tid_f20s4 = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Decimal128(20, 4) | DataType::Decimal256(20, 
4) + ) + }); + let tids = vec![tid_b10s2, tid_f20s4, tid_b10s2, tid_f20s4]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => { + let a = Decimal64Array::from_iter_values([dec_b_scale2_pos as i64, 0i64]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(10, 2) => { + let a = Decimal128Array::from_iter_values([dec_b_scale2_pos, 0]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal256(10, 2) => { + let a = Decimal256Array::from_iter_values([ + i256::from_i128(dec_b_scale2_pos), + i256::from(0), + ]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(20, 4) => { + let a = Decimal128Array::from_iter_values([dec_fix20_s4_neg, dec_fix20_s4]); + Some(Arc::new(a.with_precision_and_scale(20, 4).unwrap()) as ArrayRef) + } + DataType::Decimal256(20, 4) => { + let a = Decimal256Array::from_iter_values([ + i256::from_i128(dec_fix20_s4_neg), + i256::from_i128(dec_fix20_s4), + ]); + Some(Arc::new(a.with_precision_and_scale(20, 4).unwrap()) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 14) union_null_bytes_string: ["null","bytes","string"] + { + let (uf, _) = get_union("union_null_bytes_string"); + let tid_n = tid_by_name(&uf, "null"); + let tid_b = tid_by_name(&uf, "bytes"); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_n, tid_b, tid_s, tid_s]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "null" => Some(Arc::new(arrow_array::NullArray::new(1)) as ArrayRef), + "bytes" => Some(Arc::new(BinaryArray::from(vec![&b"\x01\x02"[..]])) as ArrayRef), + "string" => Some(Arc::new(StringArray::from(vec!["text", "u"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 15) array_of_union: array<[long,string]> 
+ { + let idx = schema.index_of("array_of_union").unwrap(); + let dt = schema.field(idx).data_type().clone(); + let (item_field, _) = match &dt { + DataType::List(f) => (f.clone(), ()), + other => panic!("array_of_union must be List, got {other:?}"), + }; + let (uf, _) = match item_field.data_type() { + DataType::Union(f, m) => (f.clone(), m), + other => panic!("array_of_union items must be Union, got {other:?}"), + }; + let tid_l = tid_by_name(&uf, "long"); + let tid_s = tid_by_name(&uf, "string"); + let type_ids = vec![tid_l, tid_s, tid_l, tid_s, tid_l, tid_l, tid_s, tid_l]; + let offsets = vec![0, 0, 1, 1, 2, 3, 2, 4]; + let values_union = + mk_dense_union(&uf, type_ids, offsets, |f| match f.name().as_str() { + "long" => { + Some(Arc::new(Int64Array::from(vec![1i64, -5, 42, -1, 0])) as ArrayRef) + } + "string" => Some(Arc::new(StringArray::from(vec!["a", "", "z"])) as ArrayRef), + _ => None, + }); + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 5, 6, 8])); + expected_cols.push(Arc::new( + ListArray::try_new(item_field.clone(), list_offsets, values_union, None).unwrap(), + )); + } + // 16) map_of_union: map<[null,double]> + { + let idx = schema.index_of("map_of_union").unwrap(); + let dt = schema.field(idx).data_type().clone(); + let (entry_field, ordered) = match &dt { + DataType::Map(f, ordered) => (f.clone(), *ordered), + other => panic!("map_of_union must be Map, got {other:?}"), + }; + let DataType::Struct(entry_fields) = entry_field.data_type() else { + panic!("map entries must be struct") + }; + let key_field = entry_fields[0].clone(); + let val_field = entry_fields[1].clone(); + let keys = StringArray::from(vec!["a", "b", "x", "pi"]); + let rounded_pi = (std::f64::consts::PI * 100_000.0).round() / 100_000.0; + let values: ArrayRef = match val_field.data_type() { + DataType::Union(uf, _) => { + let tid_n = tid_by_name(uf, "null"); + let tid_d = tid_by_name(uf, "double"); + let tids = vec![tid_n, tid_d, tid_d, tid_d]; + let offs = 
vec![0, 0, 1, 2];
+                    mk_dense_union(uf, tids, offs, |f| match f.name().as_str() {
+                        "null" => Some(Arc::new(NullArray::new(1)) as ArrayRef),
+                        "double" => Some(Arc::new(arrow_array::Float64Array::from(vec![
+                            2.5f64, -0.5f64, rounded_pi,
+                        ])) as ArrayRef),
+                        _ => None,
+                    })
+                }
+                DataType::Float64 => Arc::new(arrow_array::Float64Array::from(vec![
+                    None,
+                    Some(2.5),
+                    Some(-0.5),
+                    Some(rounded_pi),
+                ])),
+                other => panic!("unexpected map value type {other:?}"),
+            };
+            let entries = StructArray::new(
+                Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]),
+                vec![Arc::new(keys) as ArrayRef, values],
+                None,
+            );
+            let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, 2, 3, 3, 4]));
+            expected_cols.push(Arc::new(MapArray::new(
+                entry_field,
+                offsets,
+                entries,
+                None,
+                ordered,
+            )));
+        }
+        // 17) record_with_union_field: struct { id:int, u:[int,string] }
+        {
+            let idx = schema.index_of("record_with_union_field").unwrap();
+            let DataType::Struct(rec_fields) = schema.field(idx).data_type() else {
+                panic!("record_with_union_field should be Struct")
+            };
+            let id = Int32Array::from(vec![1, 2, 3, 4]);
+            let u_field = rec_fields.iter().find(|f| f.name() == "u").unwrap();
+            let DataType::Union(uf, _) = u_field.data_type() else {
+                panic!("u must be Union")
+            };
+            let tid_i = tid_by_name(uf, "int");
+            let tid_s = tid_by_name(uf, "string");
+            let tids = vec![tid_s, tid_i, tid_i, tid_s];
+            let offs = vec![0, 0, 1, 1];
+            let u = mk_dense_union(uf, tids, offs, |f| match f.name().as_str() {
+                "int" => Some(Arc::new(Int32Array::from(vec![99, 0])) as ArrayRef),
+                "string" => Some(Arc::new(StringArray::from(vec!["one", "four"])) as ArrayRef),
+                _ => None,
+            });
+            let rec = StructArray::new(rec_fields.clone(), vec![Arc::new(id) as ArrayRef, u], None);
+            expected_cols.push(Arc::new(rec));
+        }
+        // 18) union_ts_micros_utc_or_map: [timestamp-micros(TZ), map<string,long>]
+        {
+            let (uf, _) = get_union("union_ts_micros_utc_or_map");
+            let tid_ts = tid_by_dt(&uf,
|dt| {
+                matches!(
+                    dt,
+                    DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, Some(_))
+                )
+            });
+            let tid_map = tid_by_dt(&uf, |dt| matches!(dt, DataType::Map(_, _)));
+            let tids = vec![tid_ts, tid_map, tid_ts, tid_map];
+            let offs = vec![0, 0, 1, 1];
+            let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() {
+                DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => {
+                    let a = TimestampMicrosecondArray::from(vec![ts_us_2024_01_01, 0i64]);
+                    Some(Arc::new(if let Some(tz) = tz {
+                        a.with_timezone(tz.clone())
+                    } else {
+                        a
+                    }) as ArrayRef)
+                }
+                DataType::Map(entry_field, ordered) => {
+                    let DataType::Struct(fs) = entry_field.data_type() else {
+                        panic!("map entries must be struct")
+                    };
+                    let key_field = fs[0].clone();
+                    let val_field = fs[1].clone();
+                    assert_eq!(key_field.data_type(), &DataType::Utf8);
+                    assert_eq!(val_field.data_type(), &DataType::Int64);
+                    let keys = StringArray::from(vec!["k1", "k2", "n"]);
+                    let vals = Int64Array::from(vec![1i64, 2, 0]);
+                    let entries = StructArray::new(
+                        Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]),
+                        vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef],
+                        None,
+                    );
+                    let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, 2, 3]));
+                    Some(Arc::new(MapArray::new(
+                        entry_field.clone(),
+                        offsets,
+                        entries,
+                        None,
+                        *ordered,
+                    )) as ArrayRef)
+                }
+                _ => None,
+            });
+            expected_cols.push(arr);
+        }
+        // 19) union_ts_millis_local_or_string: [local-timestamp-millis, string]
+        {
+            let (uf, _) = get_union("union_ts_millis_local_or_string");
+            let tid_ts = tid_by_dt(&uf, |dt| {
+                matches!(
+                    dt,
+                    DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None)
+                )
+            });
+            let tid_s = tid_by_name(&uf, "string");
+            let tids = vec![tid_s, tid_ts, tid_s, tid_s];
+            let offs = vec![0, 0, 1, 2];
+            let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() {
+                DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) => Some(Arc::new(
+
TimestampMillisecondArray::from(vec![ts_ms_2024_01_01]), + ) + as ArrayRef), + DataType::Utf8 => { + Some( + Arc::new(StringArray::from(vec!["local midnight", "done", ""])) as ArrayRef, + ) + } + _ => None, + }); + expected_cols.push(arr); + } + // 20) union_bool_or_string: ["boolean","string"] + { + let (uf, _) = get_union("union_bool_or_string"); + let tid_b = tid_by_name(&uf, "boolean"); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_b, tid_s, tid_b, tid_s]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "boolean" => Some(Arc::new(BooleanArray::from(vec![true, false])) as ArrayRef), + "string" => Some(Arc::new(StringArray::from(vec!["no", "yes"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + let expected = RecordBatch::try_new(schema.clone(), expected_cols).unwrap(); + assert_eq!( + actual, expected, + "full end-to-end equality for union_fields.avro" + ); + } + #[test] fn test_read_zero_byte_avro_file() { let batch = read_file("test/data/zero_byte.avro", 3, false); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 29ab8a853185..950333174b26 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -456,6 +456,17 @@ impl Decoder { decoders.len() ))); } + // Proactive guard: if a user provides a union with more branches than + // a 32-bit Avro index can address, fail fast with a clear message. 
+ let branch_count = decoders.len(); + let max_addr = (i32::MAX as usize) + 1; + if branch_count > max_addr { + return Err(ArrowError::SchemaError(format!( + "Union has {branch_count} branches, which exceeds the maximum addressable \ + branches by an Avro int tag ({} + 1).", + i32::MAX + ))); + } let mut builder = UnionDecoderBuilder::new() .with_fields(fields.clone()) .with_branches(decoders); @@ -945,11 +956,6 @@ impl Decoder { | Self::StringView(offsets, values) | Self::BytesToString(offsets, values) => { let data = buf.get_bytes()?; - std::str::from_utf8(data) - .map(|_| ()) - .map_err(|e| ArrowError::ParseError(format!( - "bytes->string promotion: invalid UTF-8 ({e})" - ))); offsets.push_length(data.len()); values.extend_from_slice(data); Ok(()) @@ -1106,12 +1112,36 @@ impl Decoder { } } +// A lookup table for resolving fields between writer and reader schemas during record projection. #[derive(Debug)] struct DispatchLookupTable { + // Maps each reader field index `r` to the corresponding writer field index. + // + // Semantics: + // - `to_reader[r] >= 0`: The value is an index into the writer's fields. The value from + // the writer field is decoded, and `promotion[r]` is applied. + // - `to_reader[r] == NO_SOURCE` (-1): No matching writer field exists. The reader field's + // default value is used. + // + // Representation (`i8`): + // `i8` is used for a dense, cache-friendly dispatch table, consistent with Arrow's use of + // `i8` for union type IDs. This requires that writer field indices do not exceed `i8::MAX`. + // + // Invariants: + // - `to_reader.len() == promotion.len()` and matches the reader field count. + // - If `to_reader[r] == NO_SOURCE`, `promotion[r]` is ignored. to_reader: Box<[i8]>, + // For each reader field `r`, specifies the `Promotion` to apply to the writer's value. + // + // This is used when a writer field's type can be promoted to a reader field's type + // (e.g., `Int` to `Long`). It is ignored if `to_reader[r] == NO_SOURCE`. 
promotion: Box<[Promotion]>,
 }
 
+// Sentinel used in `DispatchLookupTable::to_reader` to mark
+// "no matching writer field".
+const NO_SOURCE: i8 = -1;
+
 impl DispatchLookupTable {
     fn from_writer_to_reader(
         promotion_map: &[Option<(usize, Promotion)>],
@@ -1131,7 +1161,7 @@ impl DispatchLookupTable {
                     promotion.push(promo);
                 }
                 None => {
-                    to_reader.push(-1);
+                    to_reader.push(NO_SOURCE);
                     promotion.push(Promotion::Direct);
                 }
             }
@@ -1208,6 +1238,16 @@ impl UnionDecoder {
         let default_emit_idx = 0;
         let null_emit_idx = null_branch.unwrap_or(default_emit_idx);
         let branch_len = branches.len().max(reader_type_codes.len());
+        // Guard against impractically large unions that cannot be indexed by an Avro int
+        let max_addr = (i32::MAX as usize) + 1;
+        if branches.len() > max_addr {
+            return Err(ArrowError::SchemaError(format!(
+                "Reader union has {} branches, which exceeds the maximum addressable \
+                branches by an Avro int tag ({} + 1).",
+                branches.len(),
+                i32::MAX
+            )));
+        }
         Ok(Self {
             fields,
             type_ids: Vec::with_capacity(DEFAULT_CAPACITY),
@@ -1275,13 +1315,22 @@ impl UnionDecoder {
 
     #[inline]
     fn read_tag(buf: &mut AvroCursor<'_>) -> Result<usize, ArrowError> {
-        let tag = buf.get_long()?;
-        if tag < 0 {
+        // Avro unions are encoded by first writing the zero-based branch index.
+        // In Avro 1.11.1 this is specified as an *int*; older specs said *long*,
+        // but both use zig-zag varint encoding, so decoding as long is compatible
+        // with either form and widely used in practice.
+        let raw = buf.get_long()?;
+        if raw < 0 {
             return Err(ArrowError::ParseError(format!(
-                "Negative union branch index {tag}"
+                "Negative union branch index {raw}"
             )));
         }
-        Ok(tag as usize)
+        usize::try_from(raw).map_err(|_| {
+            ArrowError::ParseError(format!(
+                "Union branch index {raw} does not fit into usize on this platform ({}-bit)",
+                (usize::BITS as usize)
+            ))
+        })
     }
 
     #[inline]
@@ -1780,12 +1829,23 @@ impl Skipper {
             ),
             Codec::Map(values) => Self::Map(Box::new(Skipper::from_avro(values)?)),
             Codec::Interval => Self::DurationFixed12,
-            Codec::Union(encodings, _, _) => Self::Union(
-                encodings
-                    .iter()
-                    .map(Skipper::from_avro)
-                    .collect::<Result<_, _>>()?,
-            ),
+            Codec::Union(encodings, _, _) => {
+                let max_addr = (i32::MAX as usize) + 1;
+                if encodings.len() > max_addr {
+                    return Err(ArrowError::SchemaError(format!(
+                        "Writer union has {} branches, which exceeds the maximum addressable \
+                        branches by an Avro int tag ({} + 1).",
+                        encodings.len(),
+                        i32::MAX
+                    )));
+                }
+                Self::Union(
+                    encodings
+                        .iter()
+                        .map(Skipper::from_avro)
+                        .collect::<Result<_, _>>()?,
+                )
+            }
             _ => {
                 return Err(ArrowError::NotYetImplemented(format!(
                     "Skipper not implemented for codec {:?}",
@@ -1865,13 +1925,19 @@ impl Skipper {
             }
             Self::Union(encodings) => {
                 // Union tag must be ZigZag-decoded
-                let idx = buf.get_long()?;
-                if idx < 0 {
+                let raw = buf.get_long()?;
+                if raw < 0 {
                     return Err(ArrowError::ParseError(format!(
-                        "Negative union branch index {idx}"
+                        "Negative union branch index {raw}"
                     )));
                 }
-                let Some(encoding) = encodings.get_mut(idx as usize) else {
+                let idx: usize = usize::try_from(raw).map_err(|_| {
+                    ArrowError::ParseError(format!(
+                        "Union branch index {raw} does not fit into usize on this platform ({}-bit)",
+                        (usize::BITS as usize)
+                    ))
+                })?;
+                let Some(encoding) = encodings.get_mut(idx) else {
                     return Err(ArrowError::ParseError(format!(
                         "Union branch index {idx} out of range for skipper ({} branches)",
                         encodings.len()