diff --git a/avro/src/ser_schema.rs b/avro/src/ser_schema.rs index cc1a4e0a..e62d21d1 100644 --- a/avro/src/ser_schema.rs +++ b/avro/src/ser_schema.rs @@ -22,13 +22,12 @@ use crate::{ bigdecimal::big_decimal_as_bytes, encode::{encode_int, encode_long}, error::{Details, Error}, - schema::{Name, NamesRef, Namespace, RecordSchema, Schema}, + schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema}, }; use bigdecimal::BigDecimal; -use serde::ser; +use serde::{Serialize, ser}; use std::{borrow::Cow, io::Write, str::FromStr}; -const RECORD_FIELD_INIT_BUFFER_SIZE: usize = 64; const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024; const COLLECTION_SERIALIZER_DEFAULT_INIT_ITEM_CAPACITY: usize = 32; const SINGLE_VALUE_INIT_BUFFER_SIZE: usize = 128; @@ -250,8 +249,6 @@ impl ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> { pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> { ser: &'a mut SchemaAwareWriteSerializer<'s, W>, record_schema: &'s RecordSchema, - item_count: usize, - buffered_fields: Vec>>, bytes_written: usize, } @@ -259,59 +256,32 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> { fn new( ser: &'a mut SchemaAwareWriteSerializer<'s, W>, record_schema: &'s RecordSchema, - len: usize, ) -> SchemaAwareWriteSerializeStruct<'a, 's, W> { SchemaAwareWriteSerializeStruct { ser, record_schema, - item_count: 0, - buffered_fields: vec![None; len], bytes_written: 0, } } - fn serialize_next_field(&mut self, value: &T) -> Result<(), Error> + fn serialize_next_field(&mut self, field: &RecordField, value: &T) -> Result<(), Error> where T: ?Sized + ser::Serialize, { - let next_field = self.record_schema.fields.get(self.item_count).expect( - "Validity of the next field index was expected to have been checked by the caller", - ); - // If we receive fields in order, write them directly to the main writer let mut value_ser = SchemaAwareWriteSerializer::new( &mut *self.ser.writer, - &next_field.schema, + &field.schema, self.ser.names, self.ser.enclosing_namespace.clone(), ); self.bytes_written += value.serialize(&mut value_ser)?; - self.item_count += 1; - - // Write any buffered data to the stream that has now become next in line - while let Some(buffer) = self - .buffered_fields - .get_mut(self.item_count) - .and_then(|b| b.take()) - { - self.bytes_written += self - .ser - .writer - .write(buffer.as_slice()) - .map_err(Details::WriteBytes)?; - self.item_count += 1; - } - Ok(()) } fn end(self) -> Result { - if self.item_count != self.record_schema.fields.len() { - Err(Details::GetField(self.record_schema.fields[self.item_count].name.clone()).into()) - } else { - Ok(self.bytes_written) - } + Ok(self.bytes_written) } } @@ -323,63 +293,50 @@ impl ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, where T: ?Sized + ser::Serialize, { - if self.item_count >= self.record_schema.fields.len() { - return Err(Details::FieldName(String::from(key)).into()); - } - - let next_field = &self.record_schema.fields[self.item_count]; - let next_field_matches = match &next_field.aliases { - Some(aliases) => { - key == next_field.name.as_str() || aliases.iter().any(|a| key == a.as_str()) - } - None => key == next_field.name.as_str(), - }; - - if next_field_matches { - self.serialize_next_field(&value).map_err(|e| { - Details::SerializeRecordFieldWithSchema { - field_name: key, - record_schema: Schema::Record(self.record_schema.clone()), - error: Box::new(e), - } - })?; - Ok(()) - } else { - if self.item_count < self.record_schema.fields.len() { - for i in self.item_count..self.record_schema.fields.len() { - let field = &self.record_schema.fields[i]; - let field_matches = match &field.aliases { - Some(aliases) => { - key == field.name.as_str() || aliases.iter().any(|a| key == a.as_str()) - } - None => key == field.name.as_str(), - }; - - if field_matches { - let mut buffer: Vec = Vec::with_capacity(RECORD_FIELD_INIT_BUFFER_SIZE); - let mut value_ser = SchemaAwareWriteSerializer::new( - &mut buffer, - &field.schema, - self.ser.names, - self.ser.enclosing_namespace.clone(), - ); - value.serialize(&mut value_ser).map_err(|e| { - Details::SerializeRecordFieldWithSchema { - field_name: key, - record_schema: Schema::Record(self.record_schema.clone()), - error: Box::new(e), - } - })?; - - self.buffered_fields[i] = Some(buffer); - - return Ok(()); + let record_field = self + .record_schema + .lookup + .get(key) + .and_then(|idx| self.record_schema.fields.get(*idx)); + + match record_field { + Some(field) => { + // self.item_count += 1; + self.serialize_next_field(field, value).map_err(|e| { + Details::SerializeRecordFieldWithSchema { + field_name: key, + record_schema: Schema::Record(self.record_schema.clone()), + error: Box::new(e), } - } + .into() + }) } + None => Err(Details::FieldName(String::from(key)).into()), + } + } - Err(Details::FieldName(String::from(key)).into()) + fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> { + let skipped_field = self + .record_schema + .lookup + .get(key) + .and_then(|idx| self.record_schema.fields.get(*idx)); + + if let Some(skipped_field) = skipped_field { + // self.item_count += 1; + skipped_field + .default + .serialize(&mut SchemaAwareWriteSerializer::new( + self.ser.writer, + &skipped_field.schema, + self.ser.names, + self.ser.enclosing_namespace.clone(), + ))?; + } else { + return Err(Details::GetField(key.to_string()).into()); } + + Ok(()) } fn end(self) -> Result { @@ -418,7 +375,9 @@ impl SchemaAwareWriteSerializeTupleStruct<'_, '_, W> { { use SchemaAwareWriteSerializeTupleStruct::*; match self { - Record(record_ser) => record_ser.serialize_next_field(&value), + Record(_record_ser) => { + unimplemented!("Tuple struct serialization to record is not supported!"); + } Array(array_ser) => array_ser.serialize_element(&value), } } @@ -1127,7 +1086,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { match variant_schema { Schema::Null => { /* skip */ } _ => { - encode_int(i as i32, &mut *self.writer)?; + encode_long(i as i64, &mut *self.writer)?; let mut variant_ser = SchemaAwareWriteSerializer::new( &mut *self.writer, variant_schema, @@ -1406,7 +1365,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { SchemaAwareWriteSerializeSeq::new(self, &sch.items, Some(len)), )), Schema::Record(sch) => Ok(SchemaAwareWriteSerializeTupleStruct::Record( - SchemaAwareWriteSerializeStruct::new(self, sch, len), + SchemaAwareWriteSerializeStruct::new(self, sch), )), Schema::Ref { name: ref_name } => { let ref_schema = self.get_ref_schema(ref_name)?; @@ -1543,11 +1502,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeStruct::new( - self, - record_schema, - len, - )), + Schema::Record(record_schema) => { + Ok(SchemaAwareWriteSerializeStruct::new(self, record_schema)) + } Schema::Ref { name: ref_name } => { let ref_schema = self.get_ref_schema(ref_name)?; self.serialize_struct_with_schema(name, len, ref_schema) diff --git a/avro/tests/avro-rs-226.rs b/avro/tests/avro-rs-226.rs new file mode 100644 index 00000000..10dc80db --- /dev/null +++ b/avro/tests/avro-rs-226.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use apache_avro::{AvroSchema, Schema, Writer, from_value}; +use apache_avro_test_helper::TestResult; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::fmt::Debug; + +fn ser_deser(schema: &Schema, record: T) -> TestResult +where + T: Serialize + DeserializeOwned + Debug + PartialEq + Clone, +{ + let record2 = record.clone(); + let mut writer = Writer::new(schema, vec![]); + writer.append_ser(record)?; + let bytes_written = writer.into_inner()?; + + let reader = apache_avro::Reader::new(&bytes_written[..])?; + for value in reader { + let value = value?; + let deserialized = from_value::(&value)?; + assert_eq!(deserialized, record2); + } + + Ok(()) +} + +#[test] +fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field() -> TestResult { + #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)] + struct T { + x: Option, + #[serde(skip_serializing_if = "Option::is_none")] + y: Option, + z: Option, + } + + ser_deser::( + &T::get_schema(), + T { + x: None, + y: None, + z: Some(1), + }, + ) +} + +#[test] +fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field() -> TestResult { + #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)] + struct T { + #[serde(skip_serializing_if = "Option::is_none")] + x: Option, + y: Option, + z: Option, + } + + ser_deser::( + &T::get_schema(), + T { + x: None, + y: None, + z: Some(1), + }, + ) +} + +#[test] +fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field() -> TestResult { + #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)] + struct T { + x: Option, + y: Option, + #[serde(skip_serializing_if = "Option::is_none")] + z: Option, + } + + ser_deser::( + &T::get_schema(), + T { + x: Some(0), + y: None, + z: None, + }, + ) +} + +#[test] +#[ignore = "This test should be re-enabled once the serde-driven deserialization is implemented! PR #227"] +fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResult { + #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)] + struct T { + no_skip1: Option, + #[serde(skip_serializing)] + skip_serializing: Option, + #[serde(skip_serializing_if = "Option::is_none")] + skip_serializing_if: Option, + #[serde(skip_deserializing)] + skip_deserializing: Option, + #[serde(skip)] + skip: Option, + no_skip2: Option, + } + + ser_deser::( + &T::get_schema(), + T { + no_skip1: Some(1), + skip_serializing: None, + skip_serializing_if: None, + skip_deserializing: None, + skip: None, + no_skip2: Some(2), + }, + ) +}