diff --git a/protocols/v2/serde-sv2/src/de.rs b/protocols/v2/serde-sv2/src/de.rs index c71ddc1cf1a1d..36cacbefe7196 100644 --- a/protocols/v2/serde-sv2/src/de.rs +++ b/protocols/v2/serde-sv2/src/de.rs @@ -135,6 +135,11 @@ impl<'de> Deserializer<'de> { Ok(u256) } + fn parse_signature(&'de self) -> Result<&'de [u8; 64]> { + let signature: &[u8; 64] = self.get_slice(64)?.try_into().unwrap(); + Ok(signature) + } + fn parse_string(&'de self) -> Result<&'de str> { let len = self.parse_u8()?; let str_ = self.get_slice(len as usize)?; @@ -146,36 +151,6 @@ impl<'de> Deserializer<'de> { Ok(self.get_slice(len as usize)?) } - //// TODO REMOVE!! ///// - - // Signature is parsed as Signature rather then [u8; 4] or [u8] as Signature in Sv2 rapresent - // a big int and not an array of bytes - // fn parse_signature(&'de self) -> Result { - // let signature: &[u8; 64] = self.get_slice(32)?.try_into().unwrap(); - // Ok(Signature(signature)) - // } - - // fn parse_b0255(&'de self) -> Result<&'de [u8]> { - // let len = self.parse_u8()?; - // Ok(self.get_slice(len as usize)?) - // } - - // fn parse_b064k(&'de self) -> Result<&'de [u8]> { - // let len = self.parse_u16()?; - // Ok(self.get_slice(len as usize)?) - // } - - - // fn parse_bytes(&'de self, len: usize) -> Result<&'de [u8]> { - // Ok(self.get_slice(len as usize)?) - // } - - // // Pubkey is parsed as Pubkey and not as [u8; 4] or [u8] as Pubkey in Sv2 rapresent a big int and - // // not an array of bytes - // fn parse_pubkey(&'de self) -> Result { - // let pk: &[u8; 32] = self.get_slice(32)?.try_into().unwrap(); - // Ok(U256(pk)) - // } } impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { @@ -253,6 +228,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { match _name { "U24" => visitor.visit_u32(self.parse_u24()?), "U256" => visitor.visit_bytes(self.parse_u256()?), + "Signature" => visitor.visit_bytes(self.parse_signature()?), "B016M" => visitor.visit_bytes(self.parse_b016m()?), "Seq0255" => Ok(visitor.visit_seq(Seq::new(&mut self, Sv2Seq::S255)?)?), "Seq064K" => Ok(visitor.visit_seq(Seq::new(&mut self, Sv2Seq::S64k)?)?), @@ -506,3 +482,108 @@ fn test_struct() { assert_eq!(deserialized, expected); } + +#[test] +fn test_u256() { + use serde::Serialize; + + let u256: crate::sv2_primitives::U256 = [6; 32].into(); + + #[derive(Deserialize, Serialize, PartialEq, Debug)] + struct Test { + a: crate::sv2_primitives::U256, + } + + let expected = Test { + a: u256, + }; + + let mut bytes = crate::ser::to_bytes(&expected).unwrap(); + let deserialized: Test = from_bytes(&mut bytes[..]).unwrap(); + + assert_eq!(deserialized, expected); +} + +#[test] +fn test_signature() { + use serde::Serialize; + + let s: crate::sv2_primitives::Signature = [6; 64].into(); + + #[derive(Deserialize, Serialize, PartialEq, Debug)] + struct Test { + a: crate::sv2_primitives::Signature, + } + + let expected = Test { + a: s, + }; + + let mut bytes = crate::ser::to_bytes(&expected).unwrap(); + let deserialized: Test = from_bytes(&mut bytes[..]).unwrap(); + + assert_eq!(deserialized, expected); +} + +#[test] +fn test_b016m() { + use serde::Serialize; + + let b: crate::sv2_primitives::B016M = vec![6; 3].try_into().unwrap(); + + #[derive(Deserialize, Serialize, PartialEq, Debug)] + struct Test { + a: crate::sv2_primitives::B016M, + } + + let expected = Test { + a: b, + }; + + let mut bytes = crate::ser::to_bytes(&expected).unwrap(); + let deserialized: Test = from_bytes(&mut bytes[..]).unwrap(); + + assert_eq!(deserialized, expected); +} + +#[test] +fn test_seq0255() { + use serde::Serialize; + + let s: crate::sv2_primitives::Seq0255 = vec![true; 3].try_into().unwrap(); + + #[derive(Deserialize, Serialize, PartialEq, Debug)] + struct Test { + a: crate::sv2_primitives::Seq0255, + } + + let expected = Test { + a: s, + }; + + let mut bytes = crate::ser::to_bytes(&expected).unwrap(); + let deserialized: Test = from_bytes(&mut bytes[..]).unwrap(); + + assert_eq!(deserialized, expected); +} + +#[test] +fn test_seq064k() { + use serde::Serialize; + + let s: crate::sv2_primitives::Seq0255 = vec![9; 3].try_into().unwrap(); + + #[derive(Deserialize, Serialize, PartialEq, Debug)] + struct Test { + a: crate::sv2_primitives::Seq0255, + } + + let expected = Test { + a: s, + }; + + let mut bytes = crate::ser::to_bytes(&expected).unwrap(); + let deserialized: Test = from_bytes(&mut bytes[..]).unwrap(); + + assert_eq!(deserialized, expected); +} diff --git a/protocols/v2/serde-sv2/src/error.rs b/protocols/v2/serde-sv2/src/error.rs index 504d55bd803a3..ddb0f377f1c49 100644 --- a/protocols/v2/serde-sv2/src/error.rs +++ b/protocols/v2/serde-sv2/src/error.rs @@ -18,25 +18,6 @@ pub enum Error { // field is missing. Message(String), - // Zero or more variants that can be created directly by the Serializer and - // Deserializer without going through `ser::Error` and `de::Error`. These - // are specific to the format, in this case JSON. - Eof, - Syntax, - ExpectedBoolean, - ExpectedInteger, - ExpectedString, - ExpectedNull, - ExpectedArray, - ExpectedArrayComma, - ExpectedArrayEnd, - ExpectedMap, - ExpectedMapColon, - ExpectedMapComma, - ExpectedMapEnd, - ExpectedEnum, - TrailingCharacters, - StringLenBiggerThan256, InvalidUTF8, LenBiggerThan16M, @@ -63,9 +44,9 @@ impl Display for Error { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { match self { Error::Message(msg) => formatter.write_str(msg), - Error::Eof => formatter.write_str("unexpected end of input"), + Error::WriteError => formatter.write_str("write error"), + Error::ReadError => formatter.write_str("read error"), _ => panic!(), - /* and so forth */ } } } diff --git a/protocols/v2/serde-sv2/src/sv2_primitives.rs b/protocols/v2/serde-sv2/src/sv2_primitives.rs index 45da65e126ad7..8cf112b35b791 100644 --- a/protocols/v2/serde-sv2/src/sv2_primitives.rs +++ b/protocols/v2/serde-sv2/src/sv2_primitives.rs @@ -4,7 +4,7 @@ //! use crate::error::Error; use serde::{ser, ser::SerializeTuple, Serialize, de::Visitor, Deserialize, Deserializer}; -use std::convert::TryFrom; +use std::convert::{TryInto, TryFrom}; #[derive(Debug, PartialEq)] pub struct U24(u32); @@ -54,53 +54,104 @@ impl<'de> Deserialize<'de> for U24 { } } -pub struct U256<'u256>(pub(crate) &'u256 [u8; 32]); -pub type Pubkey<'pk> = U256<'pk>; +#[derive(Debug, PartialEq)] +pub struct U256([u8; 32]); +pub type Pubkey = U256; -impl<'u256> From<&'u256 [u8; 32]> for U256<'u256> { - fn from(v: &'u256 [u8; 32]) -> Self { +impl From<[u8; 32]> for U256 { + fn from(v: [u8; 32]) -> Self { Self(v) } } -impl<'u256> From> for [u8; 32] { +impl From for [u8; 32] { fn from(v: U256) -> Self { - *v.0 + v.0 } } -impl<'u256> Serialize for U256<'u256> { +impl Serialize for U256 { fn serialize(&self, serializer: S) -> std::result::Result where S: ser::Serializer, { - serializer.serialize_bytes(self.0) + serializer.serialize_bytes(&self.0) + } +} + +struct U256Visitor; + +impl<'de> Visitor<'de> for U256Visitor { + type Value = U256; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 32 bytes unsigned le int") + } + + fn visit_bytes(self, value: &[u8]) -> Result { + let u256: [u8; 32] = value.try_into().unwrap(); + Ok(u256.into()) + } +} + +impl<'de> Deserialize<'de> for U256 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("U256", U256Visitor) } } -pub struct Signature<'sign>(pub(crate) &'sign [u8; 64]); +#[derive(Debug, PartialEq)] +pub struct Signature([u8; 64]); -impl<'sign> From<&'sign [u8; 64]> for Signature<'sign> { - fn from(v: &'sign [u8; 64]) -> Self { +impl From<[u8; 64]> for Signature { + fn from(v:[u8; 64]) -> Self { Self(v) } } -impl<'sign> From> for [u8; 64] { +impl From for [u8; 64] { fn from(v: Signature) -> Self { - *v.0 + v.0 } } -impl<'sign> Serialize for Signature<'sign> { +impl Serialize for Signature { fn serialize(&self, serializer: S) -> std::result::Result where S: ser::Serializer, { - serializer.serialize_bytes(self.0) + serializer.serialize_bytes(&self.0) } } +struct SignatureVisitor; + +impl<'de> Visitor<'de> for SignatureVisitor { + type Value = Signature; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 64 bytes unsigned le int") + } + + fn visit_bytes(self, value: &[u8]) -> Result { + let u256: [u8; 64] = value.try_into().unwrap(); + Ok(u256.into()) + } +} + +impl<'de> Deserialize<'de> for Signature { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("Signature", SignatureVisitor) + } +} + +#[derive(Debug, PartialEq)] pub struct B016M(Vec); impl TryFrom> for B016M { @@ -125,7 +176,7 @@ impl Serialize for B016M { where S: ser::Serializer, { - let tuple = (self.0.len().to_le_bytes(), &self.0[..]); + let tuple = (&self.0.len().to_le_bytes()[0..=2], &self.0[..]); let mut seq = serializer.serialize_tuple(2)?; seq.serialize_element(&tuple.0)?; seq.serialize_element(tuple.1)?; @@ -133,6 +184,31 @@ impl Serialize for B016M { } } +struct B016MVisitor; + +impl<'de> Visitor<'de> for B016MVisitor { + type Value = B016M; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a byte array shorter than 16M") + } + + fn visit_bytes(self, value: &[u8]) -> Result { + let b0: Vec = value.into(); + Ok(b0.try_into().unwrap()) + } +} + +impl<'de> Deserialize<'de> for B016M { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("B016M", B016MVisitor) + } +} + +#[derive(Debug, PartialEq)] pub struct Seq0255(Vec); pub type B0255 = Seq0255; @@ -166,7 +242,45 @@ impl Serialize for Seq0255 { } } -#[derive(Debug)] +struct Seq0255Visitor{ + _a: std::marker::PhantomData +} + +impl<'de, T: Serialize + Deserialize<'de>> Visitor<'de> for Seq0255Visitor { + type Value = Seq0255; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a general array shorter than 255") + } + + fn visit_seq>(self, mut access: A) -> Result { + let mut s: Vec = vec![]; + let mut i = 0; + while let Some(value) = access.next_element()? { + // TODO + // if i > 255 { + // return Err(Error::LenBiggerThan255) + // } + if i > 255 { + panic!() + } + s.push(value); + i += 1; + }; + Ok(s.try_into().unwrap()) + } +} + +impl<'de, T: Serialize + Deserialize<'de>> Deserialize<'de> for Seq0255 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("Seq0255", Seq0255Visitor {_a: std::marker::PhantomData}) + } +} + +#[derive(Debug, PartialEq)] pub struct Seq064K(Vec); pub type B064K = Seq064K; @@ -200,6 +314,44 @@ impl Serialize for Seq064K { } } +struct Seq064KVisitor{ + _a: std::marker::PhantomData +} + +impl<'de, T: Serialize + Deserialize<'de>> Visitor<'de> for Seq064KVisitor { + type Value = Seq064K; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a general array shorter than 64K") + } + + fn visit_seq>(self, mut access: A) -> Result { + let mut s: Vec = vec![]; + let mut i = 0; + while let Some(value) = access.next_element()? { + // TODO + // if i > 255 { + // return Err(Error::LenBiggerThan255) + // } + if i > (2_u32.pow(16)) - 1 { + panic!() + } + s.push(value); + i += 1; + }; + Ok(s.try_into().unwrap()) + } +} + +impl<'de, T: Serialize + Deserialize<'de>> Deserialize<'de> for Seq064K { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_newtype_struct("Seq064K", Seq064KVisitor {_a: std::marker::PhantomData}) + } +} + pub type Bool = bool; pub type U8 = u8; pub type U16 = u16; @@ -250,3 +402,16 @@ fn test_b0_64k_3() { Err(_) => assert!(true), } } + +#[test] +fn test_b0_16m() { + use crate::ser::to_bytes; + use std::convert::TryInto; + + let test: B016M = vec![1, 2, 9] + .try_into() + .expect("vector smaller than 64K should not fail"); + + let expected = vec![3, 0, 0, 1, 2, 9]; + assert_eq!(to_bytes(&test).unwrap(), expected); +}