diff --git a/wincode/src/schema/containers.rs b/wincode/src/schema/containers.rs index 6084dba6..3449a66f 100644 --- a/wincode/src/schema/containers.rs +++ b/wincode/src/schema/containers.rs @@ -59,8 +59,17 @@ #[cfg(all(feature = "alloc", target_has_atomic = "ptr"))] use alloc::sync::Arc as AllocArc; use { - crate::{TypeMeta, config::ConfigCore, error::ReadResult, io::Reader, schema::SchemaRead}, + crate::{ + TypeMeta, + config::ConfigCore, + error::{ReadResult, WriteResult}, + io::{Reader, Writer}, + len::SeqLen, + schema::{SchemaRead, SchemaWrite, size_of_elem_iter, write_elem_iter}, + }, core::{ + borrow::Borrow, + marker::PhantomData, mem::{self, MaybeUninit}, ptr, }, @@ -69,16 +78,9 @@ use { use { crate::{ context, - error::WriteResult, - io::Writer, - len::SeqLen, - schema::{ - SchemaReadContext, SchemaWrite, size_of_elem_iter, size_of_elem_slice, write_elem_iter, - write_elem_slice_prealloc_check, - }, + schema::{SchemaReadContext, size_of_elem_slice, write_elem_slice_prealloc_check}, }, alloc::{boxed::Box as AllocBox, collections, rc::Rc as AllocRc, vec}, - core::marker::PhantomData, }; /// A [`Vec`](std::vec::Vec) with a customizable length encoding. @@ -517,6 +519,150 @@ where } } +/// Newtype that collects a fallible iterator into `Result<V, E>` while preserving `size_hint`. +/// +/// Unlike `collect::<Result<V, E>>()`, which loses the size hint on error, this type +/// drives `V::from_iter` through an adaptor that stops on the first error but keeps +/// `size_hint` accurate so that `V` can preallocate its full expected capacity.
///
/// The wrapped value is `Ok(collection)` when every source item was `Ok`, and the
/// first `Err` yielded by the source iterator otherwise.
struct ResultPrealloc<V, E>(Result<V, E>);

impl<T, E, V: FromIterator<T>> FromIterator<Result<T, E>> for ResultPrealloc<V, E> {
    /// Feeds the `Ok` payloads of `iter` to `V::from_iter`, stopping at the first `Err`.
    ///
    /// Returns `ResultPrealloc(Ok(v))` if no error was seen, otherwise
    /// `ResultPrealloc(Err(e))` with the first error; the partially built `V` is dropped.
    fn from_iter<I: IntoIterator<Item = Result<T, E>>>(iter: I) -> ResultPrealloc<V, E> {
        /// Adaptor that yields the `Ok` payloads of `inner` and parks the first `Err`
        /// in `error`, after which it stays exhausted.
        struct Iter<I, E> {
            inner: I,
            error: Option<E>,
        }

        impl<I: Iterator<Item = Result<T, E>>, T, E> Iterator for Iter<I, E> {
            type Item = T;

            #[inline]
            fn next(&mut self) -> Option<T> {
                // Once an error has been recorded the adaptor must stay exhausted:
                // a `FromIterator` impl is allowed to poll again after `None`, and
                // resuming here would keep consuming the source past the failure
                // point and could replace the first error with a later one.
                if self.error.is_some() {
                    return None;
                }
                match self.inner.next()? {
                    Ok(item) => Some(item),
                    Err(err) => {
                        self.error = Some(err);
                        None
                    }
                }
            }

            #[inline]
            fn size_hint(&self) -> (usize, Option<usize>) {
                if self.error.is_some() {
                    // Nothing more will ever be yielded after a failure.
                    (0, Some(0))
                } else {
                    // Forwarded untouched so `V` can size its allocation up front.
                    self.inner.size_hint()
                }
            }
        }

        let mut shunt = Iter {
            inner: iter.into_iter(),
            error: None,
        };
        // Pass by `&mut` so the recorded error can be inspected afterwards.
        let result = V::from_iter(&mut shunt);
        ResultPrealloc(shunt.error.map_or(Ok(result), Err))
    }
}

/// Extension trait that adds [`collect_result_prealloc`](CollectResultExt::collect_result_prealloc)
/// to any fallible iterator, collecting into `Result<V, E>` with preallocation-friendly size hints.
trait CollectResultExt<T, E>: Iterator<Item = Result<T, E>> {
    /// Collects the `Ok` items into `V`, or returns the first `Err` encountered.
    #[inline]
    fn collect_result_prealloc<V: FromIterator<T>>(self) -> Result<V, E>
    where
        Self: Sized,
    {
        self.collect::<ResultPrealloc<V, E>>().0
    }
}
impl<T, E, I> CollectResultExt<T, E> for I where I: Iterator<Item = Result<T, E>> {}

// A generic sequence schema for custom collections that implement
// [`FromIterator`] (for reading) and whose references implement
// [`IntoIterator`] with an [`ExactSizeIterator`] (for writing).
//
// Works for both element sequences and key-value maps:
// - For element collections (sets, ordered sets, etc.) whose reference
//   iterators yield `&T`, the schema for `T` is used directly.
// - For map-like collections whose reference iterators yield `(&K, &V)` pairs,
//   the pair itself acts as the schema (automatically satisfied when `K` and
//   `V` implement `SchemaWrite`).
//
// Intended for external collection types that cannot have a dedicated
// schema impl added directly. Unlike [`Vec`], [`VecDeque`], and [`BinaryHeap`], this
// container relies on the collection's [`FromIterator`] impl rather than
// writing directly into preallocated memory.
+/// +/// # Allocation efficiency +/// +/// During deserialization, the iterator passed to [`FromIterator`] has a +/// precise [`size_hint`](Iterator::size_hint) matching the number of elements +/// produced, unless a read error is encountered. Collections whose +/// [`FromIterator`] implementation uses the size hint to preallocate capacity +/// will allocate optimally. Collections that do not use it will not benefit. +/// +/// # Examples +/// +/// ```ignore +/// use some_crate::{IndexSet, MyMap}; +/// use wincode::{SchemaRead, SchemaWrite, containers::FromIntoIterator, len::BincodeLen}; +/// +/// #[derive(SchemaRead, SchemaWrite)] +/// struct MyData { +/// #[wincode(with = "FromIntoIterator, BincodeLen>")] +/// items: IndexSet, +/// #[wincode(with = "FromIntoIterator, BincodeLen>")] +/// map: MyMap, +/// } +/// ``` +pub struct FromIntoIterator(PhantomData<(Coll, Len)>); + +unsafe impl SchemaWrite for FromIntoIterator +where + Len: SeqLen, + Coll: IntoIterator, + for<'a> &'a Coll: IntoIterator, IntoIter: ExactSizeIterator>, + for<'a> <&'a Coll as IntoIterator>::Item: + Borrow<<<&'a Coll as IntoIterator>::Item as SchemaWrite>::Src>, +{ + type Src = Coll; + + #[inline] + fn size_of(src: &Coll) -> WriteResult { + size_of_elem_iter::<<&Coll as IntoIterator>::Item, Len, C>(src.into_iter()) + } + + #[inline] + fn write(writer: impl Writer, src: &Coll) -> WriteResult<()> { + let iter = src.into_iter(); + Len::prealloc_check::(iter.len())?; + write_elem_iter::<<&Coll as IntoIterator>::Item, Len, C>(writer, iter) + } +} + +unsafe impl<'de, Coll, Len, C: ConfigCore> SchemaRead<'de, C> for FromIntoIterator +where + Len: SeqLen, + Coll: IntoIterator>, + Coll: FromIterator<>::Dst>, +{ + type Dst = Coll; + + #[inline] + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let len = + Len::read_prealloc_check::<>::Dst>(reader.by_ref())?; + + let coll = if let TypeMeta::Static { size, .. 
} = Coll::Item::TYPE_META { + #[allow(clippy::arithmetic_side_effects)] + // SAFETY: `Item::TYPE_META` specifies a static size, so `len` reads of `Item::Dst` + // will consume `size * len` bytes, fully consuming the trusted window. + let mut reader = unsafe { reader.as_trusted_for(size * len) }?; + (0..len) + .map(|_| Coll::Item::get(reader.by_ref())) + .collect_result_prealloc()? + } else { + (0..len) + .map(|_| Coll::Item::get(reader.by_ref())) + .collect_result_prealloc()? + }; + dst.write(coll); + Ok(()) + } +} + /// Decode `slice.len()` items of `T` into contiguous, uninitialized memory. /// /// Errors if fewer than `slice.len()` items are available in the [`Reader`] diff --git a/wincode/src/schema/mod.rs b/wincode/src/schema/mod.rs index 04727098..c28d235d 100644 --- a/wincode/src/schema/mod.rs +++ b/wincode/src/schema/mod.rs @@ -52,7 +52,7 @@ use { io::*, len::SeqLen, }, - core::mem::MaybeUninit, + core::{borrow::Borrow, mem::MaybeUninit}, }; pub mod containers; @@ -422,13 +422,13 @@ impl SchemaReadOwned for T where T: for<'de> SchemaRead<'de #[inline(always)] #[allow(clippy::arithmetic_side_effects)] -fn size_of_elem_iter<'a, T, Len, C>( - value: impl ExactSizeIterator, +fn size_of_elem_iter( + value: impl ExactSizeIterator>, ) -> WriteResult where C: ConfigCore, Len: SeqLen, - T: SchemaWrite + 'a, + T: SchemaWrite, { if let TypeMeta::Static { size, .. } = T::TYPE_META { return Ok(Len::write_bytes_needed(value.len())? + size * value.len()); @@ -436,7 +436,7 @@ where // Extremely unlikely a type-in-memory's size will overflow usize::MAX. Ok(Len::write_bytes_needed(value.len())? 
+ (value - .map(T::size_of) + .map(|x| T::size_of(x.borrow())) .try_fold(0usize, |acc, x| x.map(|x| acc + x))?)) } @@ -454,14 +454,14 @@ where } #[inline(always)] -fn write_elem_iter<'a, T, Len, C>( +fn write_elem_iter( mut writer: impl Writer, - src: impl ExactSizeIterator, + src: impl ExactSizeIterator>, ) -> WriteResult<()> where C: ConfigCore, Len: SeqLen, - T: SchemaWrite + 'a, + T: SchemaWrite, { if let TypeMeta::Static { size, .. } = T::TYPE_META { #[allow(clippy::arithmetic_side_effects)] @@ -472,7 +472,7 @@ where let mut writer = unsafe { writer.as_trusted_for(needed) }?; Len::write(writer.by_ref(), src.len())?; for item in src { - T::write(writer.by_ref(), item)?; + T::write(writer.by_ref(), item.borrow())?; } writer.finish()?; return Ok(()); @@ -480,7 +480,7 @@ where Len::write(writer.by_ref(), src.len())?; for item in src { - T::write(writer.by_ref(), item)?; + T::write(writer.by_ref(), item.borrow())?; } Ok(()) } @@ -2678,13 +2678,24 @@ mod tests { prop_assert_eq!(&test_data, &wincode_deserialized); prop_assert_eq!(wincode_deserialized, bincode_deserialized); + type TestMapSeq = containers::FromIntoIterator; + let test_seq_serialized = TestMapSeq::serialize(&test_data).unwrap(); + assert_eq!(test_seq_serialized, wincode_serialized); + let test_seq_deserialized = TestMapSeq::deserialize(&test_seq_serialized).unwrap(); + prop_assert_eq!(&test_data, &test_seq_deserialized); + type RegularMap = HashMap>; let regular_deserialized: RegularMap = deserialize(&wincode_serialized).unwrap(); let regular_serialized = serialize(&regular_deserialized).unwrap(); let test_deserialized: TestMap = deserialize(&regular_serialized).unwrap(); prop_assert_eq!(test_data, test_deserialized); - } + type RegularMapSeq = containers::FromIntoIterator; + let regular_seq_serialized = RegularMapSeq::serialize(&regular_deserialized).unwrap(); + assert_eq!(regular_serialized, regular_seq_serialized); + let regular_seq_deserialized = RegularMapSeq::deserialize(&regular_seq_serialized).unwrap(); +
 prop_assert_eq!(&regular_deserialized, &regular_seq_deserialized); + } #[test] fn test_btree_map_zero_copy(map in proptest::collection::btree_map(any::(), any::(), 0..=100)) {