Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 53 additions & 96 deletions avro/src/ser_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ use crate::{
bigdecimal::big_decimal_as_bytes,
encode::{encode_int, encode_long},
error::{Details, Error},
schema::{Name, NamesRef, Namespace, RecordSchema, Schema},
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
};
use bigdecimal::BigDecimal;
use serde::ser;
use serde::{Serialize, ser};
use std::{borrow::Cow, io::Write, str::FromStr};

const RECORD_FIELD_INIT_BUFFER_SIZE: usize = 64;
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
const COLLECTION_SERIALIZER_DEFAULT_INIT_ITEM_CAPACITY: usize = 32;
const SINGLE_VALUE_INIT_BUFFER_SIZE: usize = 128;
Expand Down Expand Up @@ -250,68 +249,39 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> {
pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
record_schema: &'s RecordSchema,
item_count: usize,
buffered_fields: Vec<Option<Vec<u8>>>,
bytes_written: usize,
}

impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> {
fn new(
ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
record_schema: &'s RecordSchema,
len: usize,
) -> SchemaAwareWriteSerializeStruct<'a, 's, W> {
SchemaAwareWriteSerializeStruct {
ser,
record_schema,
item_count: 0,
buffered_fields: vec![None; len],
bytes_written: 0,
}
}

fn serialize_next_field<T>(&mut self, value: &T) -> Result<(), Error>
fn serialize_next_field<T>(&mut self, field: &RecordField, value: &T) -> Result<(), Error>
where
T: ?Sized + ser::Serialize,
{
let next_field = self.record_schema.fields.get(self.item_count).expect(
"Validity of the next field index was expected to have been checked by the caller",
);

// If we receive fields in order, write them directly to the main writer
let mut value_ser = SchemaAwareWriteSerializer::new(
&mut *self.ser.writer,
&next_field.schema,
&field.schema,
self.ser.names,
self.ser.enclosing_namespace.clone(),
);
self.bytes_written += value.serialize(&mut value_ser)?;

self.item_count += 1;

// Write any buffered data to the stream that has now become next in line
while let Some(buffer) = self
.buffered_fields
.get_mut(self.item_count)
.and_then(|b| b.take())
{
self.bytes_written += self
.ser
.writer
.write(buffer.as_slice())
.map_err(Details::WriteBytes)?;
self.item_count += 1;
}

Ok(())
}

fn end(self) -> Result<usize, Error> {
if self.item_count != self.record_schema.fields.len() {
Err(Details::GetField(self.record_schema.fields[self.item_count].name.clone()).into())
} else {
Ok(self.bytes_written)
}
Ok(self.bytes_written)
}
}

Expand All @@ -323,63 +293,50 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
where
T: ?Sized + ser::Serialize,
{
if self.item_count >= self.record_schema.fields.len() {
return Err(Details::FieldName(String::from(key)).into());
}

let next_field = &self.record_schema.fields[self.item_count];
let next_field_matches = match &next_field.aliases {
Some(aliases) => {
key == next_field.name.as_str() || aliases.iter().any(|a| key == a.as_str())
}
None => key == next_field.name.as_str(),
};

if next_field_matches {
self.serialize_next_field(&value).map_err(|e| {
Details::SerializeRecordFieldWithSchema {
field_name: key,
record_schema: Schema::Record(self.record_schema.clone()),
error: Box::new(e),
}
})?;
Ok(())
} else {
if self.item_count < self.record_schema.fields.len() {
for i in self.item_count..self.record_schema.fields.len() {
let field = &self.record_schema.fields[i];
let field_matches = match &field.aliases {
Some(aliases) => {
key == field.name.as_str() || aliases.iter().any(|a| key == a.as_str())
}
None => key == field.name.as_str(),
};

if field_matches {
let mut buffer: Vec<u8> = Vec::with_capacity(RECORD_FIELD_INIT_BUFFER_SIZE);
let mut value_ser = SchemaAwareWriteSerializer::new(
&mut buffer,
&field.schema,
self.ser.names,
self.ser.enclosing_namespace.clone(),
);
value.serialize(&mut value_ser).map_err(|e| {
Details::SerializeRecordFieldWithSchema {
field_name: key,
record_schema: Schema::Record(self.record_schema.clone()),
error: Box::new(e),
}
})?;

self.buffered_fields[i] = Some(buffer);

return Ok(());
let record_field = self
.record_schema
.lookup
.get(key)
.and_then(|idx| self.record_schema.fields.get(*idx));

match record_field {
Some(field) => {
// self.item_count += 1;
self.serialize_next_field(field, value).map_err(|e| {
Details::SerializeRecordFieldWithSchema {
field_name: key,
record_schema: Schema::Record(self.record_schema.clone()),
error: Box::new(e),
}
}
.into()
})
}
None => Err(Details::FieldName(String::from(key)).into()),
}
}

Err(Details::FieldName(String::from(key)).into())
/// Handles a struct field that the `Serialize` impl skipped by writing the schema default.
///
/// `key` is resolved through the record schema's `lookup` table to the matching field. That
/// field's `default` is then serialized against the field's own schema, directly into the
/// underlying writer, and the bytes produced are added to the running `bytes_written` total
/// (the same bookkeeping `serialize_next_field` performs for regular fields).
///
/// # Errors
///
/// Returns `Details::GetField` when `key` does not resolve to a field of the record schema,
/// and propagates any error raised while serializing the default value.
fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
    let Some(skipped_field) = self
        .record_schema
        .lookup
        .get(key)
        .and_then(|idx| self.record_schema.fields.get(*idx))
    else {
        return Err(Details::GetField(key.to_string()).into());
    };

    let mut default_ser = SchemaAwareWriteSerializer::new(
        &mut *self.ser.writer,
        &skipped_field.schema,
        self.ser.names,
        self.ser.enclosing_namespace.clone(),
    );

    // The default's bytes land in the output stream, so they must be counted; otherwise the
    // total returned by `end` undercounts whenever a field is skipped.
    self.bytes_written += skipped_field.default.serialize(&mut default_ser)?;

    Ok(())
}

fn end(self) -> Result<Self::Ok, Self::Error> {
Expand Down Expand Up @@ -418,7 +375,9 @@ impl<W: Write> SchemaAwareWriteSerializeTupleStruct<'_, '_, W> {
{
use SchemaAwareWriteSerializeTupleStruct::*;
match self {
Record(record_ser) => record_ser.serialize_next_field(&value),
Record(_record_ser) => {
unimplemented!("Tuple struct serialization to record is not supported!");
}
Array(array_ser) => array_ser.serialize_element(&value),
}
}
Expand Down Expand Up @@ -1127,7 +1086,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
match variant_schema {
Schema::Null => { /* skip */ }
_ => {
encode_int(i as i32, &mut *self.writer)?;
encode_long(i as i64, &mut *self.writer)?;
let mut variant_ser = SchemaAwareWriteSerializer::new(
&mut *self.writer,
variant_schema,
Expand Down Expand Up @@ -1406,7 +1365,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
SchemaAwareWriteSerializeSeq::new(self, &sch.items, Some(len)),
)),
Schema::Record(sch) => Ok(SchemaAwareWriteSerializeTupleStruct::Record(
SchemaAwareWriteSerializeStruct::new(self, sch, len),
SchemaAwareWriteSerializeStruct::new(self, sch),
)),
Schema::Ref { name: ref_name } => {
let ref_schema = self.get_ref_schema(ref_name)?;
Expand Down Expand Up @@ -1543,11 +1502,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
};

match schema {
Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeStruct::new(
self,
record_schema,
len,
)),
Schema::Record(record_schema) => {
Ok(SchemaAwareWriteSerializeStruct::new(self, record_schema))
}
Schema::Ref { name: ref_name } => {
let ref_schema = self.get_ref_schema(ref_name)?;
self.serialize_struct_with_schema(name, len, ref_schema)
Expand Down
130 changes: 130 additions & 0 deletions avro/tests/avro-rs-226.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use apache_avro::{AvroSchema, Schema, Writer, from_value};
use apache_avro_test_helper::TestResult;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt::Debug;

/// Round-trips `record` through the Avro container format and verifies the result.
///
/// The record is serialized with a [`Writer`] built from `schema`; the produced bytes are then
/// read back with `apache_avro::Reader`, every decoded value is converted to `T` with
/// [`from_value`] and compared against the original record.
///
/// # Errors
///
/// Propagates any error reported while writing, reading or deserializing.
///
/// # Panics
///
/// Panics if a decoded record differs from `record`, or if the reader does not yield exactly
/// one record (which would otherwise let the comparison loop pass vacuously).
fn ser_deser<T>(schema: &Schema, record: T) -> TestResult
where
    T: Serialize + DeserializeOwned + Debug + PartialEq + Clone,
{
    // Keep a copy for the comparison; `append_ser` takes the record by value.
    let expected = record.clone();
    let mut writer = Writer::new(schema, vec![]);
    writer.append_ser(record)?;
    let bytes_written = writer.into_inner()?;

    let reader = apache_avro::Reader::new(&bytes_written[..])?;
    let mut records_read = 0_usize;
    for value in reader {
        let value = value?;
        let deserialized = from_value::<T>(&value)?;
        assert_eq!(deserialized, expected);
        records_read += 1;
    }

    // Exactly one record was appended, so exactly one must come back.
    assert_eq!(
        records_read, 1,
        "expected exactly one record to be read back"
    );

    Ok(())
}

/// Round-trips a record whose middle field `y` carries
/// `skip_serializing_if = "Option::is_none"` and is `None`.
#[test]
fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field() -> TestResult {
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        x: Option<i8>,
        #[serde(skip_serializing_if = "Option::is_none")]
        y: Option<String>,
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: None,
        y: None,
        z: Some(1),
    };

    ser_deser(&schema, record)
}

/// Round-trips a record whose first field `x` carries
/// `skip_serializing_if = "Option::is_none"` and is `None`.
#[test]
fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field() -> TestResult {
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        #[serde(skip_serializing_if = "Option::is_none")]
        x: Option<i8>,
        y: Option<String>,
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: None,
        y: None,
        z: Some(1),
    };

    ser_deser(&schema, record)
}

/// Round-trips a record whose last field `z` carries
/// `skip_serializing_if = "Option::is_none"` and is `None`.
#[test]
fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_last_field() -> TestResult {
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        x: Option<i8>,
        y: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        z: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        x: Some(0),
        y: None,
        z: None,
    };

    ser_deser(&schema, record)
}

/// Round-trips a record that mixes several serde skip attributes (`skip_serializing`,
/// `skip_serializing_if`, `skip_deserializing` and `skip`) between two regular fields.
///
/// Currently ignored; see the `#[ignore]` reason below.
#[test]
#[ignore = "This test should be re-enabled once the serde-driven deserialization is implemented! PR #227"]
fn avro_rs_226_index_out_of_bounds_with_serde_skip_multiple_fields() -> TestResult {
    #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct T {
        no_skip1: Option<i8>,
        #[serde(skip_serializing)]
        skip_serializing: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        skip_serializing_if: Option<i8>,
        #[serde(skip_deserializing)]
        skip_deserializing: Option<String>,
        #[serde(skip)]
        skip: Option<String>,
        no_skip2: Option<i8>,
    }

    let schema = T::get_schema();
    let record = T {
        no_skip1: Some(1),
        skip_serializing: None,
        skip_serializing_if: None,
        skip_deserializing: None,
        skip: None,
        no_skip2: Some(2),
    };

    ser_deser(&schema, record)
}