diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index a10e3a238d3c..89a66ddbaa85 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,13 +16,14 @@ // under the License. use crate::schema::{ - Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, TypeName, + Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, Type, TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, }; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, }; +use serde_json::Value; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -32,7 +33,7 @@ use std::sync::Arc; /// /// To accommodate this we special case two-variant unions where one of the /// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub enum Nullability { /// The nulls are encoded as the first union variant NullFirst, @@ -40,6 +41,95 @@ pub enum Nullability { NullSecond, } +/// Contains information about how to resolve differences between a writer's and a reader's schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum ResolutionInfo { + /// Indicates that the writer's type should be promoted to the reader's type. + Promotion(Promotion), + /// Indicates that a default value should be used for a field. (Implemented in a Follow-up PR) + DefaultValue(AvroLiteral), + /// Provides mapping information for resolving enums. (Implemented in a Follow-up PR) + EnumMapping(EnumMapping), + /// Provides resolution information for record fields. (Implemented in a Follow-up PR) + Record(ResolvedRecord), +} + +/// Represents a literal Avro value. +/// +/// This is used to represent default values in an Avro schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum AvroLiteral { + /// Represents a null value. + Null, + /// Represents a boolean value. + Boolean(bool), + /// Represents an integer value. + Int(i32), + /// Represents a long value. + Long(i64), + /// Represents a float value. + Float(f32), + /// Represents a double value. + Double(f64), + /// Represents a bytes value. + Bytes(Vec), + /// Represents a string value. + String(String), + /// Represents an enum symbol. + Enum(String), + /// Represents an unsupported literal type. + Unsupported, +} + +/// Contains the necessary information to resolve a writer's record against a reader's record schema. +#[derive(Debug, Clone, PartialEq)] +pub struct ResolvedRecord { + /// Maps a writer's field index to the corresponding reader's field index. + /// `None` if the writer's field is not present in the reader's schema. + pub(crate) writer_to_reader: Arc<[Option]>, + /// A list of indices in the reader's schema for fields that have a default value. + pub(crate) default_fields: Arc<[usize]>, + /// For fields present in the writer's schema but not the reader's, this stores their data type. + /// This is needed to correctly skip over these fields during deserialization. + pub(crate) skip_fields: Arc<[Option]>, +} + +/// Defines the type of promotion to be applied during schema resolution. +/// +/// Schema resolution may require promoting a writer's data type to a reader's data type. +/// For example, an `int` can be promoted to a `long`, `float`, or `double`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Promotion { + /// Promotes an `int` to a `long`. + IntToLong, + /// Promotes an `int` to a `float`. + IntToFloat, + /// Promotes an `int` to a `double`. + IntToDouble, + /// Promotes a `long` to a `float`. + LongToFloat, + /// Promotes a `long` to a `double`. + LongToDouble, + /// Promotes a `float` to a `double`. + FloatToDouble, + /// Promotes a `string` to `bytes`. + StringToBytes, + /// Promotes `bytes` to a `string`. + BytesToString, +} + +/// Holds the mapping information for resolving Avro enums. +/// +/// When resolving schemas, the writer's enum symbols must be mapped to the reader's symbols. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EnumMapping { + /// A mapping from the writer's symbol index to the reader's symbol index. + pub(crate) mapping: Arc<[i32]>, + /// The index to use for a writer's symbol that is not present in the reader's enum + /// and a default value is specified in the reader's schema. + pub(crate) default_index: i32, +} + #[cfg(feature = "canonical_extension_types")] fn with_extension_type(codec: &Codec, field: Field) -> Field { match codec { @@ -49,11 +139,12 @@ fn with_extension_type(codec: &Codec, field: Field) -> Field { } /// An Avro datatype mapped to the arrow data model -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AvroDataType { nullability: Option, metadata: HashMap, codec: Codec, + pub(crate) resolution: Option, } impl AvroDataType { @@ -67,6 +158,22 @@ impl AvroDataType { codec, metadata, nullability, + resolution: None, + } + } + + #[inline] + fn new_with_resolution( + codec: Codec, + metadata: HashMap, + nullability: Option, + resolution: Option, + ) -> Self { + Self { + codec, + metadata, + nullability, + resolution, } } @@ -102,7 +209,7 @@ impl AvroDataType { } /// A named [`AvroDataType`] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AvroField { name: String, data_type: AvroDataType, @@ -154,9 +261,16 @@ impl AvroField { use_utf8view: bool, strict_mode: bool, ) -> Result { - Err(ArrowError::NotYetImplemented( - "Resolving schema from a writer and reader schema is not yet implemented".to_string(), - )) + let top_name = match reader_schema { + Schema::Complex(ComplexType::Record(r)) => r.name.to_string(), + _ => "root".to_string(), + }; + let mut resolver = Maker::new(use_utf8view, strict_mode); + let data_type = resolver.make_data_type(writer_schema, Some(reader_schema), None)?; + Ok(Self { + name: top_name, + data_type, + }) } } @@ -166,8 +280,8 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { fn try_from(schema: &Schema<'a>) -> Result { match schema { Schema::Complex(ComplexType::Record(r)) => { - let mut resolver = Resolver::default(); - let data_type = make_data_type(schema, None, &mut resolver, false, false)?; + let mut resolver = Maker::new(false, false); + let data_type = resolver.make_data_type(schema, None, None)?; Ok(AvroField { data_type, name: r.name.to_string(), @@ -184,7 +298,7 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { #[derive(Debug)] pub struct AvroFieldBuilder<'a> { writer_schema: &'a Schema<'a>, - reader_schema: Option, + reader_schema: Option<&'a Schema<'a>>, use_utf8view: bool, strict_mode: bool, } @@ -205,7 +319,7 @@ impl<'a> AvroFieldBuilder<'a> { /// If a reader schema is provided, the builder will produce a resolved `AvroField` /// that can handle differences between the writer's and reader's schemas. #[inline] - pub fn with_reader_schema(mut self, reader_schema: AvroSchema) -> Self { + pub fn with_reader_schema(mut self, reader_schema: &'a Schema<'a>) -> Self { self.reader_schema = Some(reader_schema); self } @@ -226,14 +340,9 @@ impl<'a> AvroFieldBuilder<'a> { pub fn build(self) -> Result { match self.writer_schema { Schema::Complex(ComplexType::Record(r)) => { - let mut resolver = Resolver::default(); - let data_type = make_data_type( - self.writer_schema, - None, - &mut resolver, - self.use_utf8view, - self.strict_mode, - )?; + let mut resolver = Maker::new(self.use_utf8view, self.strict_mode); + let data_type = + resolver.make_data_type(self.writer_schema, self.reader_schema, None)?; Ok(AvroField { name: r.name.to_string(), data_type, @@ -250,7 +359,7 @@ impl<'a> AvroFieldBuilder<'a> { /// An Avro encoding /// /// -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum Codec { /// Represents Avro null type, maps to Arrow's Null data type Null, @@ -478,221 +587,417 @@ impl<'a> Resolver<'a> { } } -/// Parses a [`AvroDataType`] from the provided `schema` and the given `name` and `namespace` -/// -/// `name`: is name used to refer to `schema` in its parent -/// `namespace`: an optional qualifier used as part of a type hierarchy -/// If the data type is a string, convert to use Utf8View if requested -/// -/// This function is used during the schema conversion process to determine whether -/// string data should be represented as StringArray (default) or StringViewArray. -/// -/// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types +/// Resolves Avro type names to [`AvroDataType`] /// -/// See [`Resolver`] for more information -fn make_data_type<'a>( - schema: &Schema<'a>, - namespace: Option<&'a str>, - resolver: &mut Resolver<'a>, +/// See +struct Maker<'a> { + resolver: Resolver<'a>, use_utf8view: bool, strict_mode: bool, -) -> Result { - match schema { - Schema::TypeName(TypeName::Primitive(p)) => { - let codec: Codec = (*p).into(); - let codec = codec.with_utf8view(use_utf8view); - Ok(AvroDataType { - nullability: None, - metadata: Default::default(), - codec, - }) +} + +impl<'a> Maker<'a> { + fn new(use_utf8view: bool, strict_mode: bool) -> Self { + Self { + resolver: Default::default(), + use_utf8view, + strict_mode, } - Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - Schema::Union(f) => { - // Special case the common case of nullable primitives - let null = f - .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { - (true, Some(0)) => { - let mut field = - make_data_type(&f[1], namespace, resolver, use_utf8view, strict_mode)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) - } - (true, Some(1)) => { - if strict_mode { - return Err(ArrowError::SchemaError( - "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" - .to_string(), - )); + } + fn make_data_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: Option<&'s Schema<'a>>, + namespace: Option<&'a str>, + ) -> Result { + match reader_schema { + Some(reader_schema) => self.resolve_type(writer_schema, reader_schema, namespace), + None => self.parse_type(writer_schema, namespace), + } + } + + /// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` + /// + /// `name`: is the name used to refer to `schema` in its parent + /// `namespace`: an optional qualifier used as part of a type hierarchy + /// If the data type is a string, convert to use Utf8View if requested + /// + /// This function is used during the schema conversion process to determine whether + /// string data should be represented as StringArray (default) or StringViewArray. + /// + /// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types + /// + /// See [`Resolver`] for more information + fn parse_type<'s>( + &mut self, + schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + match schema { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType::new( + Codec::from(*p).with_utf8view(self.use_utf8view), + Default::default(), + None, + )), + Schema::TypeName(TypeName::Ref(name)) => self.resolver.resolve(name, namespace), + Schema::Union(f) => { + // Special case the common case of nullable primitives + let null = f + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); + match (f.len() == 2, null) { + (true, Some(0)) => { + let mut field = self.parse_type(&f[1], namespace)?; + field.nullability = Some(Nullability::NullFirst); + Ok(field) } - let mut field = - make_data_type(&f[0], namespace, resolver, use_utf8view, strict_mode)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) + (true, Some(1)) => { + if self.strict_mode { + return Err(ArrowError::SchemaError( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + .to_string(), + )); + } + let mut field = self.parse_type(&f[0], namespace)?; + field.nullability = Some(Nullability::NullSecond); + Ok(field) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Union of {f:?} not currently supported" + ))), } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), } - } - Schema::Complex(c) => match c { - ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); - let fields = r - .fields - .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type( - &field.r#type, - namespace, - resolver, - use_utf8view, - strict_mode, - )?, + Schema::Complex(c) => match c { + ComplexType::Record(r) => { + let namespace = r.namespace.or(namespace); + let fields = r + .fields + .iter() + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: self.parse_type(&field.r#type, namespace)?, + }) }) + .collect::>()?; + let field = AvroDataType { + nullability: None, + codec: Codec::Struct(fields), + metadata: r.attributes.field_metadata(), + resolution: None, + }; + self.resolver.register(r.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Array(a) => { + let mut field = self.parse_type(a.items.as_ref(), namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), + resolution: None, }) - .collect::>()?; - let field = AvroDataType { - nullability: None, - codec: Codec::Struct(fields), - metadata: r.attributes.field_metadata(), - }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Array(a) => { - let mut field = make_data_type( - a.items.as_ref(), - namespace, - resolver, - use_utf8view, - strict_mode, - )?; - Ok(AvroDataType { - nullability: None, - metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), - }) - } - ComplexType::Fixed(f) => { - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - let md = f.attributes.field_metadata(); - let field = match f.attributes.logical_type { - Some("decimal") => { - let (precision, scale, _) = - parse_decimal_attributes(&f.attributes, Some(size as usize), true)?; - AvroDataType { - nullability: None, - metadata: md, - codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + } + ComplexType::Fixed(f) => { + let size = f.size.try_into().map_err(|e| { + ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) + })?; + let md = f.attributes.field_metadata(); + let field = match f.attributes.logical_type { + Some("decimal") => { + let (precision, scale, _) = + parse_decimal_attributes(&f.attributes, Some(size as usize), true)?; + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + resolution: None, + } } - } - Some("duration") => { - if size != 12 { - return Err(ArrowError::ParseError(format!( - "Invalid fixed size for Duration: {size}, must be 12" - ))); - }; - AvroDataType { + Some("duration") => { + if size != 12 { + return Err(ArrowError::ParseError(format!( + "Invalid fixed size for Duration: {size}, must be 12" + ))); + }; + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Interval, + resolution: None, + } + } + _ => AvroDataType { nullability: None, metadata: md, - codec: Codec::Interval, - } - } - _ => AvroDataType { + codec: Codec::Fixed(size), + resolution: None, + }, + }; + self.resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Enum(e) => { + let namespace = e.namespace.or(namespace); + let symbols = e + .symbols + .iter() + .map(|s| s.to_string()) + .collect::>(); + + let mut metadata = e.attributes.field_metadata(); + let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| { + ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}")) + })?; + metadata.insert(AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), symbols_json); + let field = AvroDataType { nullability: None, - metadata: md, - codec: Codec::Fixed(size), - }, - }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Enum(e) => { - let namespace = e.namespace.or(namespace); - let symbols = e - .symbols - .iter() - .map(|s| s.to_string()) - .collect::>(); - - let mut metadata = e.attributes.field_metadata(); - let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| { - ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}")) - })?; - metadata.insert(AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), symbols_json); - let field = AvroDataType { - nullability: None, - metadata, - codec: Codec::Enum(symbols), - }; - resolver.register(e.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Map(m) => { - let val = - make_data_type(&m.values, namespace, resolver, use_utf8view, strict_mode)?; - Ok(AvroDataType { - nullability: None, - metadata: m.attributes.field_metadata(), - codec: Codec::Map(Arc::new(val)), - }) - } - }, - Schema::Type(t) => { - let mut field = make_data_type( - &Schema::TypeName(t.r#type.clone()), - namespace, - resolver, - use_utf8view, - strict_mode, - )?; - - // https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Binary) => { - let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; - *c = Codec::Decimal(prec, Some(sc), None); + metadata, + codec: Codec::Enum(symbols), + resolution: None, + }; + self.resolver.register(e.name, namespace, field.clone()); + Ok(field) } - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) + ComplexType::Map(m) => { + let val = self.parse_type(&m.values, namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(val)), + resolution: None, + }) } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) + }, + Schema::Type(t) => { + let mut field = self.parse_type(&Schema::TypeName(t.r#type.clone()), namespace)?; + // https://avro.apache.org/docs/1.11.1/specification/#logical-types + match (t.attributes.logical_type, &mut field.codec) { + (Some("decimal"), c @ Codec::Binary) => { + let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; + *c = Codec::Decimal(prec, Some(sc), None); + } + (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, + (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, + (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, + (Some("timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(true) + } + (Some("timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(true) + } + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } + (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, + (Some(logical), _) => { + // Insert unrecognized logical type into metadata map + field.metadata.insert("logicalType".into(), logical.into()); + } + (None, _) => {} } - (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata map - field.metadata.insert("logicalType".into(), logical.into()); + if !t.attributes.additional.is_empty() { + for (k, v) in &t.attributes.additional { + field.metadata.insert(k.to_string(), v.to_string()); + } } - (None, _) => {} + Ok(field) } + } + } - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); - } + fn resolve_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + match (writer_schema, reader_schema) { + ( + Schema::TypeName(TypeName::Primitive(writer_primitive)), + Schema::TypeName(TypeName::Primitive(reader_primitive)), + ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), + ( + Schema::Type(Type { + r#type: TypeName::Primitive(writer_primitive), + .. + }), + Schema::Type(Type { + r#type: TypeName::Primitive(reader_primitive), + .. + }), + ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), + ( + Schema::TypeName(TypeName::Primitive(writer_primitive)), + Schema::Type(Type { + r#type: TypeName::Primitive(reader_primitive), + .. + }), + ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), + ( + Schema::Type(Type { + r#type: TypeName::Primitive(writer_primitive), + .. + }), + Schema::TypeName(TypeName::Primitive(reader_primitive)), + ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema), + ( + Schema::Complex(ComplexType::Record(writer_record)), + Schema::Complex(ComplexType::Record(reader_record)), + ) => self.resolve_records(writer_record, reader_record, namespace), + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => { + self.resolve_nullable_union(writer_variants, reader_variants, namespace) + } + _ => Err(ArrowError::NotYetImplemented( + "Other resolutions not yet implemented".to_string(), + )), + } + } + + fn resolve_primitives( + &mut self, + write_primitive: PrimitiveType, + read_primitive: PrimitiveType, + reader_schema: &Schema<'a>, + ) -> Result { + if write_primitive == read_primitive { + return self.parse_type(reader_schema, None); + } + let promotion = match (write_primitive, read_primitive) { + (PrimitiveType::Int, PrimitiveType::Long) => Promotion::IntToLong, + (PrimitiveType::Int, PrimitiveType::Float) => Promotion::IntToFloat, + (PrimitiveType::Int, PrimitiveType::Double) => Promotion::IntToDouble, + (PrimitiveType::Long, PrimitiveType::Float) => Promotion::LongToFloat, + (PrimitiveType::Long, PrimitiveType::Double) => Promotion::LongToDouble, + (PrimitiveType::Float, PrimitiveType::Double) => Promotion::FloatToDouble, + (PrimitiveType::String, PrimitiveType::Bytes) => Promotion::StringToBytes, + (PrimitiveType::Bytes, PrimitiveType::String) => Promotion::BytesToString, + _ => { + return Err(ArrowError::ParseError(format!( + "Illegal promotion {write_primitive:?} to {read_primitive:?}" + ))) } - Ok(field) + }; + let mut datatype = self.parse_type(reader_schema, None)?; + datatype.resolution = Some(ResolutionInfo::Promotion(promotion)); + Ok(datatype) + } + + fn resolve_nullable_union( + &mut self, + writer_variants: &[Schema<'a>], + reader_variants: &[Schema<'a>], + namespace: Option<&'a str>, + ) -> Result { + // Only support unions with exactly two branches, one of which is `null` on both sides + if writer_variants.len() != 2 || reader_variants.len() != 2 { + return Err(ArrowError::NotYetImplemented( + "Only 2-branch unions are supported for schema resolution".to_string(), + )); } + let is_null = |s: &Schema<'a>| { + matches!( + s, + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)) + ) + }; + let w_null_pos = writer_variants.iter().position(is_null); + let r_null_pos = reader_variants.iter().position(is_null); + match (w_null_pos, r_null_pos) { + (Some(wp), Some(rp)) => { + // Extract a non-null branch on each side + let w_nonnull = &writer_variants[1 - wp]; + let r_nonnull = &reader_variants[1 - rp]; + // Resolve the non-null branch + let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?; + // Adopt reader union null ordering + dt.nullability = Some(match rp { + 0 => Nullability::NullFirst, + 1 => Nullability::NullSecond, + _ => unreachable!(), + }); + Ok(dt) + } + _ => Err(ArrowError::NotYetImplemented( + "Union resolution requires both writer and reader to be nullable unions" + .to_string(), + )), + } + } + + fn resolve_records( + &mut self, + writer_record: &Record<'a>, + reader_record: &Record<'a>, + namespace: Option<&'a str>, + ) -> Result { + // Names must match or be aliased + let names_match = writer_record.name == reader_record.name + || reader_record.aliases.contains(&writer_record.name) + || writer_record.aliases.contains(&reader_record.name); + if !names_match { + return Err(ArrowError::ParseError(format!( + "Record name mismatch writer={}, reader={}", + writer_record.name, reader_record.name + ))); + } + let writer_ns = writer_record.namespace.or(namespace); + let reader_ns = reader_record.namespace.or(namespace); + // Map writer field name -> index + let mut writer_index_map = + HashMap::<&str, usize>::with_capacity(writer_record.fields.len()); + for (idx, write_field) in writer_record.fields.iter().enumerate() { + writer_index_map.insert(write_field.name, idx); + } + // Prepare outputs + let mut reader_fields: Vec = Vec::with_capacity(reader_record.fields.len()); + let mut writer_to_reader: Vec> = vec![None; writer_record.fields.len()]; + //let mut skip_fields: Vec> = vec![None; writer_record.fields.len()]; + //let mut default_fields: Vec = Vec::new(); + // Build reader fields and mapping + for (reader_idx, r_field) in reader_record.fields.iter().enumerate() { + if let Some(&writer_idx) = writer_index_map.get(r_field.name) { + // Field exists in writer: resolve types (including promotions and union-of-null) + let w_schema = &writer_record.fields[writer_idx].r#type; + let resolved_dt = + self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?; + reader_fields.push(AvroField { + name: r_field.name.to_string(), + data_type: resolved_dt, + }); + writer_to_reader[writer_idx] = Some(reader_idx); + } else { + return Err(ArrowError::NotYetImplemented( + "New fields from reader with default values not yet implemented".to_string(), + )); + } + } + // Implement writer-only fields to skip in Follow-up PR here + // Build resolved record AvroDataType + let resolved = AvroDataType::new_with_resolution( + Codec::Struct(Arc::from(reader_fields)), + reader_record.attributes.field_metadata(), + None, + Some(ResolutionInfo::Record(ResolvedRecord { + writer_to_reader: Arc::from(writer_to_reader), + default_fields: Arc::default(), + skip_fields: Arc::default(), + })), + ); + // Register a resolved record by reader name+namespace for potential named type refs + self.resolver + .register(reader_record.name, reader_ns, resolved.clone()); + Ok(resolved) } } #[cfg(test)] mod tests { use super::*; - use crate::schema::{Attributes, PrimitiveType, Schema, Type, TypeName}; + use crate::schema::{Attributes, Fixed, PrimitiveType, Schema, Type, TypeName}; use serde_json; fn create_schema_with_logical_type( @@ -710,12 +1015,36 @@ mod tests { }) } + fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> { + let attributes = Attributes { + logical_type: Some(logical_type), + additional: Default::default(), + }; + + Schema::Complex(ComplexType::Fixed(Fixed { + name: "fixed_type", + namespace: None, + aliases: Vec::new(), + size, + attributes, + })) + } + + fn resolve_promotion(writer: PrimitiveType, reader: PrimitiveType) -> AvroDataType { + let writer_schema = Schema::TypeName(TypeName::Primitive(writer)); + let reader_schema = Schema::TypeName(TypeName::Primitive(reader)); + let mut maker = Maker::new(false, false); + maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("promotion should resolve") + } + #[test] fn test_date_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "date"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Date32)); } @@ -724,8 +1053,8 @@ mod tests { fn test_time_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimeMillis)); } @@ -734,8 +1063,8 @@ mod tests { fn test_time_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimeMicros)); } @@ -744,8 +1073,8 @@ mod tests { fn test_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(true))); } @@ -754,8 +1083,8 @@ mod tests { fn test_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(true))); } @@ -764,8 +1093,8 @@ mod tests { fn test_local_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(false))); } @@ -774,8 +1103,8 @@ mod tests { fn test_local_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(false))); } @@ -822,13 +1151,12 @@ mod tests { panic!("Expected NotYetImplemented error"); } } - #[test] fn test_unknown_logical_type_added_to_metadata() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert_eq!( result.metadata.get("logicalType"), @@ -840,8 +1168,8 @@ mod tests { fn test_string_with_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8View)); } @@ -850,8 +1178,8 @@ mod tests { fn test_string_without_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8)); } @@ -878,8 +1206,8 @@ mod tests { let schema = Schema::Complex(ComplexType::Record(record)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = &result.codec { let first_field_codec = &fields[0].data_type().codec; @@ -896,8 +1224,8 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false, true); + let mut maker = Maker::new(false, true); + let result = maker.make_data_type(&schema, None, None); assert!(result.is_err()); match result { @@ -910,6 +1238,126 @@ mod tests { } } + #[test] + fn test_resolve_int_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToFloat)) + ); + } + + #[test] + fn test_resolve_int_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + } + + #[test] + fn test_resolve_long_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToFloat)) + ); + } + + #[test] + fn test_resolve_long_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToDouble)) + ); + } + + #[test] + fn test_resolve_float_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Float, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::FloatToDouble)) + ); + } + + #[test] + fn test_resolve_string_to_bytes_promotion() { + let result = resolve_promotion(PrimitiveType::String, PrimitiveType::Bytes); + assert!(matches!(result.codec, Codec::Binary)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::StringToBytes)) + ); + } + + #[test] + fn test_resolve_bytes_to_string_promotion() { + let result = resolve_promotion(PrimitiveType::Bytes, PrimitiveType::String); + assert!(matches!(result.codec, Codec::Utf8)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::BytesToString)) + ); + } + + #[test] + fn test_resolve_illegal_promotion_double_to_float_errors() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer_schema, Some(&reader_schema), None); + assert!(result.is_err()); + match result { + Err(ArrowError::ParseError(msg)) => { + assert!(msg.contains("Illegal promotion")); + } + _ => panic!("Expected ParseError for illegal promotion Double -> Float"), + } + } + + #[test] + fn test_promotion_within_nullable_union_keeps_reader_null_ordering() { + let writer = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + ]); + let reader = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + assert_eq!(result.nullability, Some(Nullability::NullSecond)); + } + + #[test] + fn test_resolve_type_promotion() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + let mut maker = Maker::new(false, false); + let result = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert!(matches!(result.codec, Codec::Int64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + ); + } + #[test] fn test_nested_record_type_reuse_without_namespace() { let schema_str = r#" @@ -936,8 +1384,8 @@ mod tests { let schema: Schema = serde_json::from_str(schema_str).unwrap(); - let mut resolver = Resolver::default(); - let avro_data_type = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = avro_data_type.codec() { assert_eq!(fields.len(), 4); @@ -1016,8 +1464,8 @@ mod tests { let schema: Schema = serde_json::from_str(schema_str).unwrap(); - let mut resolver = Resolver::default(); - let avro_data_type = make_data_type(&schema, None, &mut resolver, false, false).unwrap(); + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = avro_data_type.codec() { assert_eq!(fields.len(), 4); diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 7bbcaeb9f027..802a3df8b70b 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -372,11 +372,11 @@ impl ReaderBuilder { fn make_record_decoder( &self, writer_schema: &Schema, - reader_schema: Option<&AvroSchema>, + reader_schema: Option<&Schema>, ) -> Result { let mut builder = AvroFieldBuilder::new(writer_schema); if let Some(reader_schema) = reader_schema { - builder = builder.with_reader_schema(reader_schema.clone()); + builder = builder.with_reader_schema(reader_schema); } let root = builder .with_utf8view(self.utf8_view) @@ -385,6 +385,15 @@ impl ReaderBuilder { RecordDecoder::try_new_with_options(root.data_type(), self.utf8_view) } + fn make_record_decoder_from_schemas( + &self, + writer_schema: &Schema, + reader_schema: Option<&AvroSchema>, + ) -> Result { + let reader_schema_raw = reader_schema.map(|s| s.schema()).transpose()?; + self.make_record_decoder(writer_schema, reader_schema_raw.as_ref()) + } + fn make_decoder_with_parts( &self, active_decoder: RecordDecoder, @@ -418,7 +427,8 @@ impl ReaderBuilder { .ok_or_else(|| { ArrowError::ParseError("No Avro schema present in file header".into()) })?; - let record_decoder = self.make_record_decoder(&writer_schema, reader_schema)?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; return Ok(self.make_decoder_with_parts( record_decoder, None, @@ -453,11 +463,12 @@ impl ReaderBuilder { } }; let writer_schema = avro_schema.schema()?; - let decoder = self.make_record_decoder(&writer_schema, reader_schema)?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; if fingerprint == start_fingerprint { - active_decoder = Some(decoder); + active_decoder = Some(record_decoder); } else { - cache.insert(fingerprint, decoder); + cache.insert(fingerprint, record_decoder); } } let active_decoder = active_decoder.ok_or_else(|| { @@ -662,6 +673,7 @@ mod test { use bytes::{Buf, BufMut, Bytes}; use futures::executor::block_on; use futures::{stream, Stream, StreamExt, TryStreamExt}; + use serde_json::Value; use std::collections::HashMap; use std::fs; use std::fs::File; @@ -804,10 +816,10 @@ mod test { #[test] fn test_unknown_fingerprint_is_error() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); let unknown_fp = Fingerprint::Rabin(0xDEAD_BEEF_DEAD_BEEF); let prefix = make_prefix(unknown_fp); - let mut decoder = make_decoder(&store, fp_int, &schema_int); + let mut decoder = make_decoder(&store, fp_int, &schema_long); let err = decoder.decode(&prefix).expect_err("decode should error"); let msg = err.to_string(); assert!( @@ -818,8 +830,8 @@ mod test { #[test] fn test_handle_prefix_incomplete_magic() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); let buf = &SINGLE_OBJECT_MAGIC[..1]; let res = decoder.handle_prefix(buf).unwrap(); assert_eq!(res, Some(0)); @@ -828,8 +840,8 @@ mod test { #[test] fn test_handle_prefix_magic_mismatch() { - let (store, fp_int, _fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); let buf = [0xFFu8, 0x00u8, 0x01u8]; let res = decoder.handle_prefix(&buf).unwrap(); assert!(res.is_none()); @@ -837,8 +849,8 @@ mod test { #[test] fn test_handle_prefix_incomplete_fingerprint() { - let (store, fp_int, fp_long, schema_int, _schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); + let (store, fp_int, fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); let long_bytes = match fp_long { Fingerprint::Rabin(v) => v.to_le_bytes(), }; @@ -851,8 +863,8 @@ mod test { #[test] fn test_handle_prefix_valid_prefix_switches_schema() { - let (store, fp_int, fp_long, schema_int, schema_long) = make_two_schema_store(); - let mut decoder = make_decoder(&store, fp_int, &schema_int); + let (store, fp_int, fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); let writer_schema_long = schema_long.schema().unwrap(); let root_long = AvroFieldBuilder::new(&writer_schema_long).build().unwrap(); let long_decoder =