diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 802a3df8b70b..3f2daff0a3b1 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -802,6 +802,411 @@ mod test { msg } + fn load_writer_schema_json(path: &str) -> Value { + let file = File::open(path).unwrap(); + let header = super::read_header(BufReader::new(file)).unwrap(); + let schema = header.schema().unwrap().unwrap(); + serde_json::to_value(&schema).unwrap() + } + + fn make_reader_schema_with_promotions( + path: &str, + promotions: &HashMap<&str, &str>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_ty) = promotions.get(name) { + let ty = f.get_mut("type").expect("field has a type"); + match ty { + Value::String(_) => { + *ty = Value::String((*new_ty).to_string()); + } + // Union + Value::Array(arr) => { + for b in arr.iter_mut() { + match b { + Value::String(s) if s != "null" => { + *b = Value::String((*new_ty).to_string()); + break; + } + Value::Object(_) => { + *b = Value::String((*new_ty).to_string()); + break; + } + _ => {} + } + } + } + Value::Object(_) => { + *ty = Value::String((*new_ty).to_string()); + } + _ => {} + } + } + } + AvroSchema::new(root.to_string()) + } + + fn read_alltypes_with_reader_schema(path: &str, reader_schema: AvroSchema) -> RecordBatch { + let file = File::open(path).unwrap(); + let reader = ReaderBuilder::new() + .with_batch_size(1024) + .with_utf8_view(false) + .with_reader_schema(reader_schema) + .build(BufReader::new(file)) + .unwrap(); + + let schema = reader.schema(); + let batches = reader.collect::<Result<Vec<_>, _>>().unwrap(); + arrow::compute::concat_batches(&schema, &batches).unwrap() + } + + #[test] + 
fn test_alltypes_schema_promotion_mixed() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("id", "long"); + promotions.insert("tinyint_col", "float"); + promotions.insert("smallint_col", "double"); + promotions.insert("int_col", "double"); + promotions.insert("bigint_col", "double"); + promotions.insert("float_col", "double"); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int64Array::from(vec![4i64, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32), + )) as _, + true, + ), + ( + "smallint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "int_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "bigint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f64), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) as f32 * 1.1f32) as f64), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", 
"02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::<Vec<&str>>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_long_to_float_only() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bigint_col", "float"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + 
Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f32), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_bytes_to_string_only() { + let files = [ + "avro/alltypes_plain.avro", + "avro/alltypes_plain.snappy.avro", + "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", + ]; + for file in files { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, 
&promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::<Vec<&str>>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn 
test_alltypes_illegal_promotion_bool_to_double_errors() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bool_col", "double"); // illegal + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let file_handle = File::open(&file).unwrap(); + let result = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .build(BufReader::new(file_handle)); + let err = result.expect_err("expected illegal promotion to error"); + let msg = err.to_string(); + assert!( + msg.contains("Illegal promotion") || msg.contains("illegal promotion"), + "unexpected error: {msg}" + ); + } + #[test] fn test_schema_store_register_lookup() { let schema_int = make_record_schema(PrimitiveType::Int); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 180afcd2d8c3..a51e4c78740f 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::codec::{AvroDataType, Codec, Nullability}; +use crate::codec::{AvroDataType, Codec, Nullability, Promotion, ResolutionInfo}; use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::reader::header::Header; @@ -154,6 +154,14 @@ enum Decoder { TimeMicros(Vec<i64>), TimestampMillis(bool, Vec<i64>), TimestampMicros(bool, Vec<i64>), + Int32ToInt64(Vec<i64>), + Int32ToFloat32(Vec<f32>), + Int32ToFloat64(Vec<f64>), + Int64ToFloat32(Vec<f32>), + Int64ToFloat64(Vec<f64>), + Float32ToFloat64(Vec<f64>), + BytesToString(OffsetBufferBuilder<i32>, Vec<u8>), + StringToBytes(OffsetBufferBuilder<i32>, Vec<u8>), Binary(OffsetBufferBuilder<i32>, Vec<u8>), /// String data encoded as UTF-8 bytes, mapped to Arrow's StringArray String(OffsetBufferBuilder<i32>, Vec<u8>), @@ -179,36 +187,68 @@ enum Decoder { impl Decoder { fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> { - let decoder = match data_type.codec() { - Codec::Null => Self::Null(0), - Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Self::Binary( + // Extract just the Promotion (if any) to simplify pattern matching + let promotion = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Promotion(p)) => Some(p), + _ => None, + }; + let decoder = match (data_type.codec(), promotion) { + (Codec::Int64, Some(Promotion::IntToLong)) => { + Self::Int32ToInt64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::IntToFloat)) => { + Self::Int32ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::IntToDouble)) => { + Self::Int32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::LongToFloat)) => { + 
Self::Int64ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::LongToDouble)) => { + Self::Int64ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::FloatToDouble)) => { + Self::Float32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Utf8, Some(Promotion::BytesToString)) + | (Codec::Utf8View, Some(Promotion::BytesToString)) => Self::BytesToString( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + (Codec::Binary, Some(Promotion::StringToBytes)) => Self::StringToBytes( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8View => Self::StringView( + (Codec::Null, _) => Self::Null(0), + (Codec::Boolean, _) => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + (Codec::Int32, _) => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Int64, _) => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float32, _) => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float64, _) => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Binary, _) => Self::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { + (Codec::Utf8, _) => Self::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Utf8View, _) => Self::StringView( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Date32, _) => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimeMillis, _) => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + 
(Codec::TimeMicros, _) => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimestampMillis(is_utc), _) => { Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::TimestampMicros(is_utc) => { + (Codec::TimestampMicros(is_utc), _) => { Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Decimal(precision, scale, size) => { + (Codec::Fixed(sz), _) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Decimal(precision, scale, size), _) => { let p = *precision; let s = *scale; let sz = *size; @@ -247,8 +287,8 @@ impl Decoder { } } } - Codec::Interval => Self::Duration(IntervalMonthDayNanoBuilder::new()), - Codec::List(item) => { + (Codec::Interval, _) => Self::Duration(IntervalMonthDayNanoBuilder::new()), + (Codec::List(item), _) => { let decoder = Self::try_new(item)?; Self::Array( Arc::new(item.field_with_name("item")), @@ -256,10 +296,10 @@ impl Decoder { Box::new(decoder), ) } - Codec::Enum(symbols) => { + (Codec::Enum(symbols), _) => { Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) } - Codec::Struct(fields) => { + (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); let mut encodings = Vec::with_capacity(fields.len()); for avro_field in fields.iter() { @@ -269,7 +309,7 @@ impl Decoder { } Self::Record(arrow_fields.into(), encodings) } - Codec::Map(child) => { + (Codec::Map(child), _) => { let val_field = child.field_with_name("value").with_nullable(true); let map_field = Arc::new(ArrowField::new( "entries", @@ -288,7 +328,7 @@ impl Decoder { Box::new(val_dec), ) } - Codec::Uuid => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), }; Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( @@ -307,12 +347,20 @@ impl Decoder { Self::Boolean(b) => b.append(false), 
Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), Self::Int64(v) + | Self::Int32ToInt64(v) | Self::TimeMicros(v) | Self::TimestampMillis(_, v) | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => { + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => v.push(0.), + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => v.push(0.), + Self::Binary(offsets, _) + | Self::String(offsets, _) + | Self::StringView(offsets, _) + | Self::BytesToString(offsets, _) + | Self::StringToBytes(offsets, _) => { offsets.push_length(0); } Self::Uuid(v) => { @@ -353,7 +401,15 @@ impl Decoder { | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), Self::Float32(values) => values.push(buf.get_float()?), Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) + Self::Int32ToInt64(values) => values.push(buf.get_int()? as i64), + Self::Int32ToFloat32(values) => values.push(buf.get_int()? as f32), + Self::Int32ToFloat64(values) => values.push(buf.get_int()? as f64), + Self::Int64ToFloat32(values) => values.push(buf.get_long()? as f32), + Self::Int64ToFloat64(values) => values.push(buf.get_long()? as f64), + Self::Float32ToFloat64(values) => values.push(buf.get_float()? 
as f64), + Self::StringToBytes(offsets, values) + | Self::BytesToString(offsets, values) + | Self::Binary(offsets, values) + | Self::String(offsets, values) + | Self::StringView(offsets, values) => { let data = buf.get_bytes()?; @@ -464,12 +520,21 @@ impl Decoder { ), Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)), Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)), - Self::Binary(offsets, values) => { + Self::Int32ToInt64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)), + Self::Int32ToFloat32(values) | Self::Int64ToFloat32(values) => { + Arc::new(flush_primitive::<Float32Type>(values, nulls)) + } + Self::Int32ToFloat64(values) + | Self::Int64ToFloat64(values) + | Self::Float32ToFloat64(values) => { + Arc::new(flush_primitive::<Float64Type>(values, nulls)) + } + Self::StringToBytes(offsets, values) | Self::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Arc::new(BinaryArray::new(offsets, values, nulls)) } - Self::String(offsets, values) => { + Self::BytesToString(offsets, values) | Self::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Arc::new(StringArray::new(offsets, values, nulls)) @@ -672,6 +737,7 @@ fn sign_extend_to<const N: usize>(raw: &[u8]) -> Result<[u8; N], ArrowError> { #[cfg(test)] mod tests { use super::*; + use crate::codec::AvroField; use arrow_array::{ cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, @@ -709,6 +775,185 @@ mod tests { AvroDataType::new(codec, Default::default(), None) } + + fn decoder_for_promotion( + writer: PrimitiveType, + reader: PrimitiveType, + use_utf8view: bool, + ) -> Decoder { + let ws = Schema::TypeName(TypeName::Primitive(writer)); + let rs = Schema::TypeName(TypeName::Primitive(reader)); + let field = + AvroField::resolve_from_writer_and_reader(&ws, &rs, use_utf8view, false).unwrap(); + 
Decoder::try_new(field.data_type()).unwrap() + } + + #[test] + fn test_schema_resolution_promotion_int_to_long() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Long, false); + assert!(matches!(dec, Decoder::Int32ToInt64(_))); + for v in [0, 1, -2, 123456] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Int64Array>().unwrap(); + assert_eq!(a.value(0), 0); + assert_eq!(a.value(1), 1); + assert_eq!(a.value(2), -2); + assert_eq!(a.value(3), 123456); + } + + #[test] + fn test_schema_resolution_promotion_int_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int32ToFloat32(_))); + for v in [0, 42, -7] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Float32Array>().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 42.0); + assert_eq!(a.value(2), -7.0); + } + + #[test] + fn test_schema_resolution_promotion_int_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int32ToFloat64(_))); + for v in [1, -1, 10_000] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Float64Array>().unwrap(); + assert_eq!(a.value(0), 1.0); + assert_eq!(a.value(1), -1.0); + assert_eq!(a.value(2), 10_000.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int64ToFloat32(_))); + for v in [0_i64, 1_000_000_i64, -123_i64] { + let data = encode_avro_long(v); + let mut cur = 
AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Float32Array>().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 1_000_000.0); + assert_eq!(a.value(2), -123.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int64ToFloat64(_))); + for v in [2_i64, -2_i64, 9_223_372_i64] { + let data = encode_avro_long(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Float64Array>().unwrap(); + assert_eq!(a.value(0), 2.0); + assert_eq!(a.value(1), -2.0); + assert_eq!(a.value(2), 9_223_372.0); + } + + #[test] + fn test_schema_resolution_promotion_float_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Float, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Float32ToFloat64(_))); + for v in [0.5_f32, -3.25_f32, 1.0e6_f32] { + let data = v.to_le_bytes().to_vec(); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Float64Array>().unwrap(); + assert_eq!(a.value(0), 0.5_f64); + assert_eq!(a.value(1), -3.25_f64); + assert_eq!(a.value(2), 1.0e6_f64); + } + + #[test] + fn test_schema_resolution_promotion_bytes_to_string_utf8() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, false); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + for s in ["hello", "world", "héllo"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<StringArray>().unwrap(); + assert_eq!(a.value(0), "hello"); + assert_eq!(a.value(1), "world"); + assert_eq!(a.value(2), "héllo"); + } + + #[test] + fn 
test_schema_resolution_promotion_bytes_to_string_utf8view_enabled() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, true); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + let data = encode_avro_bytes("abc".as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<StringArray>().unwrap(); + assert_eq!(a.value(0), "abc"); + } + + #[test] + fn test_schema_resolution_promotion_string_to_bytes() { + let mut dec = decoder_for_promotion(PrimitiveType::String, PrimitiveType::Bytes, false); + assert!(matches!(dec, Decoder::StringToBytes(_, _))); + for s in ["", "abc", "data"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<BinaryArray>().unwrap(); + assert_eq!(a.value(0), b""); + assert_eq!(a.value(1), b"abc"); + assert_eq!(a.value(2), "data".as_bytes()); + } + + #[test] + fn test_schema_resolution_no_promotion_passthrough_int() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let field = AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false).unwrap(); + let mut dec = Decoder::try_new(field.data_type()).unwrap(); + assert!(matches!(dec, Decoder::Int32(_))); + for v in [7, -9] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(a.value(0), 7); + assert_eq!(a.value(1), -9); + } + + #[test] + fn test_schema_resolution_illegal_promotion_int_to_boolean_errors() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)); + let res = 
AvroField::resolve_from_writer_and_reader(&ws, &rs, false, false); + assert!(res.is_err(), "expected error for illegal promotion"); + } + #[test] fn test_map_decoding_one_entry() { let value_type = avro_from_codec(Codec::Utf8);