Skip to content

Commit

Permalink
feat: max_size: various small refactoring
Browse files Browse the repository at this point in the history
* Change is_zero_size to short-circuit internal all function when
  first element which isn’t zero sized is found.

* Change is_zero_size to use vector of str slices as stack.  There’s
  no need to allocate strings.  This also makes it possible to share
  the same stack as max_serialized_size_impl.

* Introduce Recursive error type to better encode error condition of
  is_zero_size.

* Change max_serialized_size’s count argument to NonZeroUsize to
  encode in type system that the count is in fact never zero.  In
  places where it might end up zero, the function short-circuits.
  • Loading branch information
mina86 committed Sep 17, 2023
1 parent d22259a commit 8c00f72
Showing 1 changed file with 44 additions and 33 deletions.
77 changes: 44 additions & 33 deletions borsh/src/schema/container_ext/max_size.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use super::{BorshSchemaContainer, Declaration, Definition, Fields};
use crate::__private::maybestd::{string::ToString, vec::Vec};

use core::num::NonZeroUsize;

/// NonZeroUsize of value one.
// TODO: Replace usage by NonZeroUsize::MIN once MSRV is 1.70+.
const ONE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) };

impl BorshSchemaContainer {
/// Returns the largest possible size of a serialised object based solely on its type.
///
Expand Down Expand Up @@ -39,7 +45,7 @@ impl BorshSchemaContainer {
/// ```
pub fn max_serialized_size(&self) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
let mut stack = Vec::new();
max_serialized_size_impl(1, self.declaration(), self, &mut stack)
max_serialized_size_impl(ONE, self.declaration(), self, &mut stack)
}
}

Expand All @@ -64,7 +70,7 @@ pub enum SchemaMaxSerializedSizeError {

/// Implementation of [`BorshSchema::max_serialized_size`].
fn max_serialized_size_impl<'a>(
count: usize,
count: NonZeroUsize,
declaration: &'a str,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
Expand All @@ -73,28 +79,29 @@ fn max_serialized_size_impl<'a>(

/// Maximum number of elements in a vector or length of a string which can
/// be serialised.
const MAX_LEN: usize = u32::MAX as usize;
const MAX_LEN: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(u32::MAX as usize) };

fn add(x: usize, y: usize) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
x.checked_add(y)
.ok_or(SchemaMaxSerializedSizeError::Overflow)
}

fn mul(x: usize, y: usize) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
x.checked_mul(y)
fn mul(x: NonZeroUsize, y: usize) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
x.get()
.checked_mul(y)
.ok_or(SchemaMaxSerializedSizeError::Overflow)
}

/// Calculates max serialised size of a tuple with given members.
fn tuple<'a>(
count: usize,
count: NonZeroUsize,
elements: impl core::iter::IntoIterator<Item = &'a Declaration>,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
) -> ::core::result::Result<usize, SchemaMaxSerializedSizeError> {
let mut sum: usize = 0;
for el in elements {
sum = add(sum, max_serialized_size_impl(1, el, schema, stack)?)?;
sum = add(sum, max_serialized_size_impl(ONE, el, schema, stack)?)?;
}
mul(count, sum)
}
Expand All @@ -110,15 +117,12 @@ fn max_serialized_size_impl<'a>(
// overflows, check if array’s element is zero-sized.
let count = usize::try_from(*length)
.ok()
.and_then(|len| len.checked_mul(count));
.and_then(|len| len.checked_mul(count.get()))
.map(NonZeroUsize::new);
match count {
Some(0) => Ok(0),
Some(count) => max_serialized_size_impl(count, elements, schema, stack),
None if is_zero_size(elements, schema)
.map_err(|_err| SchemaMaxSerializedSizeError::Recursive)? =>
{
Ok(0)
}
Some(None) => Ok(0),
Some(Some(count)) => max_serialized_size_impl(count, elements, schema, stack),
None if is_zero_size_impl(elements.as_str(), schema, stack)? => Ok(0),
None => Err(SchemaMaxSerializedSizeError::Overflow),
}
}
Expand All @@ -135,7 +139,7 @@ fn max_serialized_size_impl<'a>(
}) => {
let mut max = 0;
for (_, variant) in variants {
let sz = max_serialized_size_impl(1, variant, schema, stack)?;
let sz = max_serialized_size_impl(ONE, variant, schema, stack)?;
max = max.max(sz);
}
max.checked_add(usize::from(*tag_width))
Expand All @@ -154,14 +158,14 @@ fn max_serialized_size_impl<'a>(

// Primitive types.
Err("nil") => Ok(0),
Err("bool" | "i8" | "u8" | "nonzero_i8" | "nonzero_u8") => Ok(count),
Err("bool" | "i8" | "u8" | "nonzero_i8" | "nonzero_u8") => Ok(count.get()),
Err("i16" | "u16" | "nonzero_i16" | "nonzero_u16") => mul(count, 2),
Err("i32" | "u32" | "f32" | "nonzero_i32" | "nonzero_u32") => mul(count, 4),
Err("i64" | "u64" | "f64" | "nonzero_i64" | "nonzero_u64") => mul(count, 8),
Err("i128" | "u128" | "nonzero_i128" | "nonzero_u128") => mul(count, 16),

// string is just Vec<u8>
Err("string") => mul(count, add(MAX_LEN, 4)?),
Err("string") => mul(count, add(MAX_LEN.get(), 4)?),

Err(declaration) => Err(SchemaMaxSerializedSizeError::MissingDefinition(
declaration.to_string(),
Expand All @@ -187,37 +191,44 @@ fn max_serialized_size_impl<'a>(
pub(super) fn is_zero_size(
declaration: &Declaration,
schema: &BorshSchemaContainer,
) -> Result<bool, ()> {
) -> Result<bool, Recursive> {
let mut stack = Vec::new();
is_zero_size_impl(declaration, schema, &mut stack)
}
const RECURSIVE: () = ();

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct Recursive;

impl From<Recursive> for SchemaMaxSerializedSizeError {
fn from(_: Recursive) -> Self {
Self::Recursive
}
}

fn is_zero_size_impl<'a>(
declaration: &'a str,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<Declaration>,
) -> Result<bool, ()> {
fn all<T>(
stack: &mut Vec<&'a str>,
) -> Result<bool, Recursive> {
fn all<'a, T: 'a>(
iter: impl Iterator<Item = T>,
f_key: impl Fn(&T) -> &Declaration,
schema: &BorshSchemaContainer,
stack: &mut Vec<Declaration>,
) -> Result<bool, ()> {
let mut all = true;
f_key: impl Fn(&T) -> &'a Declaration,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
) -> Result<bool, Recursive> {
for element in iter {
let declaration = f_key(&element);
if !is_zero_size_impl(declaration.as_str(), schema, stack)? {
all = false;
return Ok(false);
}
}
Ok(all)
Ok(true)
}

if stack.iter().any(|dec| *dec == declaration) {
return Err(RECURSIVE);
return Err(Recursive);
}
stack.push(declaration.to_string());
stack.push(declaration);

let res = match schema.get_definition(declaration).ok_or(declaration) {
Ok(Definition::Array { length, elements }) => {
Expand Down Expand Up @@ -308,7 +319,7 @@ mod tests {
struct RecursiveNoExitStructUnnamed(Box<RecursiveNoExitStructUnnamed>);

let schema = BorshSchemaContainer::for_type::<RecursiveNoExitStructUnnamed>();
assert_eq!(Err(()), is_zero_size(schema.declaration(), &schema));
assert_eq!(Err(Recursive), is_zero_size(schema.declaration(), &schema));
}

#[test]
Expand Down

0 comments on commit 8c00f72

Please sign in to comment.