diff --git a/Cargo.lock b/Cargo.lock index 50e147e6e961b4..90fec23d446f83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1885,6 +1885,30 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "chacha20" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" +dependencies = [ + "cfg-if 1.0.4", + "cipher", + "cpufeatures", +] + +[[package]] +name = "chacha20poly1305" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + [[package]] name = "chrono" version = "0.4.42" @@ -1943,6 +1967,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ "crypto-common", "inout", + "zeroize", ] [[package]] @@ -5475,6 +5500,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly1305" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" +dependencies = [ + "cpufeatures", + "opaque-debug 0.3.0", + "universal-hash", +] + [[package]] name = "polyval" version = "0.6.2" @@ -11365,6 +11401,36 @@ dependencies = [ "x509-parser", ] +[[package]] +name = "solana-tlv" +version = "4.0.0-alpha.0" +dependencies = [ + "bencher", + "bincode", + "bytes", + "chacha20", + "chacha20poly1305", + "poly1305", + "serde", + "serde_with", + "solana-short-vec", + "thiserror 2.0.17", + "wincode", +] + +[[package]] +name = "solana-tlv-mac" +version = "4.0.0-alpha.0" +dependencies = [ + "bytes", + "chacha20", + "chacha20poly1305", + "poly1305", + "solana-short-vec", + "solana-tlv", + "wincode", +] + [[package]] name = "solana-tokens" version = "4.0.0-alpha.0" diff --git a/Cargo.toml b/Cargo.toml index f71d76785e5b6d..78e41020273ee5 100644 --- 
a/Cargo.toml +++ b/Cargo.toml @@ -113,6 +113,8 @@ members = [ "test-validator", "thread-manager", "tls-utils", + "tlv", + "tlv/tlv-mac", "tokens", "tps-client", "tpu-client", @@ -566,6 +568,7 @@ solana-sysvar-id = "3.0.0" solana-test-validator = { path = "test-validator", version = "=4.0.0-alpha.0" } solana-time-utils = "3.0.0" solana-tls-utils = { path = "tls-utils", version = "=4.0.0-alpha.0", features = ["agave-unstable-api"] } +solana-tlv = { path = "tlv", version = "=4.0.0-alpha.0" } solana-tps-client = { path = "tps-client", version = "=4.0.0-alpha.0", features = ["agave-unstable-api"] } solana-tpu-client = { path = "tpu-client", version = "=4.0.0-alpha.0", default-features = false, features = ["agave-unstable-api"] } solana-tpu-client-next = { path = "tpu-client-next", version = "=4.0.0-alpha.0", features = ["agave-unstable-api"] } diff --git a/tlv/Cargo.toml b/tlv/Cargo.toml new file mode 100644 index 00000000000000..441ef221a01bcb --- /dev/null +++ b/tlv/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "solana-tlv" +description = "Solana TLV implementation" +documentation = "https://docs.rs/solana-tlv" +version = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = false + +[features] +agave-unstable-api = [] + +[dependencies] +bytes = { workspace = true } +chacha20 = "0.9.1" +chacha20poly1305 = { version = "0.10.1" } +poly1305 = "0.8.0" +thiserror = { workspace = true } +wincode = { workspace = true, features = ["derive"] } +solana-short-vec = { workspace = true } + +[dev-dependencies] +bencher = { workspace = true } +bincode = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_with = { workspace = true, features = ["macros"] } +[lints] +workspace = true + +[[bench]] +name = "tlv_vs_wincode" +harness = false diff --git a/tlv/README.md b/tlv/README.md new file mode 100644 index 
00000000000000..ca1ce3b0b24e89 --- /dev/null +++ b/tlv/README.md @@ -0,0 +1,64 @@ +## Tag-Length-Value data support for Solana + +TLV (Tag-Length-Value, also known as Type-Length-Value) is a well-established format to encode binary data on the wire, offering major advantages compared to most alternatives: +1. Ability to evolve existing protocols without a hard version switch +2. Efficient parsing and serialization +3. Perfect forward compatibility + +This is somewhat similar to protobuf, except that the receiver does not need to be +able to parse all of the records to be able to read the others. + +## Wire format + +A packet consists of a sequence of byte-aligned records. Each record contains: +* tag:u8 - 1 byte, cannot be zero +* length:u16 - 1-3 bytes on the wire (1 byte if the length is less than 128, uses solana-short-vec impl) +* value - 1..MTU bytes + +The records can be appended as needed to form compound packets. As a safety precaution, +the maximum size of any value is capped at MAX_VALUE_LENGTH = 1280 bytes. + +## Defining enums + +Any Rust enum can be turned into a TLV compatible encoding with a macro: +```rust +use solana_tlv::{define_tlv_enum, signature::Signature}; +use bytes::Bytes; + +define_tlv_enum! (pub(crate) enum Extension { + 1=>Thing(u64), // this will use bincode + 3=>DoGood(()), // this will store the tag and no data + 4=>Mac(Signature<16>), // and this allows to sign packets + #[raw] + 5=>ByteArray(Bytes), // this will get bytes included verbatim +}); +``` + +Variant tags must be unique. Reusing them causes parsing errors. 
+ +Intended workflow: +```rust +use bytes::{Bytes, BytesMut}; + +let tlv_data = vec![ + Extension::Thing(42), + Extension::ByteArray(Bytes::from(vec![77u8; 256])), +]; +let mut buffer = BytesMut::with_capacity(2000); +serialize_into_buffer(&tlv_data, &mut buffer).unwrap(); +// send buffer over the wire +let recovered_data: Vec = deserialize_from_buffer(buffer.freeze()).collect(); +``` + +## Signatures + +A `solana_tlv_mac::Signature` entry can be attached to the end of a message to sign the whole message. + +## Performance + +This crate has not been heavily optimized and likely has room for further improvement + +## Caveats + +Since the `define_tlv_enum` is a macro, you need to include serde into dependencies of any crate using the +macro. diff --git a/tlv/benches/tlv_vs_wincode.rs b/tlv/benches/tlv_vs_wincode.rs new file mode 100644 index 00000000000000..3b9b370cd73f12 --- /dev/null +++ b/tlv/benches/tlv_vs_wincode.rs @@ -0,0 +1,179 @@ +#![allow(clippy::arithmetic_side_effects)] + +#[macro_use] +extern crate bencher; + +use { + bencher::Bencher, + /*bytes::BytesMut, + serde_with::serde_as, + solana_short_vec as short_vec, + solana_tlv::*, + std::mem::MaybeUninit, + wincode::{ + containers::{self, Pod}, + io, + len::ShortU16Len, + SchemaRead, + }, + wincode_derive::{SchemaRead, SchemaWrite},*/ +}; + +fn tlv_roundtrip(_slot_num: u64) { + /* use serde::{Deserialize, Serialize}; + + #[serde_as] + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + struct Finalize { + pubkey: [u8; 32], + #[serde_as(as = "[_; 96]")] + bls_signature: [u8; 96], + } + + #[serde_as] + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + struct NotarizeCert { + #[serde_as(as = "[_; 96]")] + bls_signature: [u8; 96], + #[serde(with = "short_vec")] + bitmap: Vec, + } + + define_tlv_enum! 
(pub(crate) enum AlpenglowVotor { + 1=>Slot(u64), + 10=>Finalize(Finalize), + 11=>NotarizeCert(NotarizeCert), + }); + let notar_cert = NotarizeCert { + bitmap: vec![42u8; 2000 / 8], + bls_signature: [7; 96], + }; + let final_vote = Finalize { + pubkey: [3; 32], + bls_signature: [7; 96], + }; + + // allocate space for a packet and fill it with data + let mut buffer = BytesMut::with_capacity(1200); + let entries = [ + AlpenglowVotor::Slot(slot_num), + AlpenglowVotor::Finalize(final_vote), + AlpenglowVotor::NotarizeCert(notar_cert), + ]; + serialize_into_buffer(&entries, &mut buffer).unwrap(); + + let buffer = buffer.freeze(); + let mut recovered = vec![]; + for (_size, maybe_record) in TlvIter::new(buffer) { + let maybe_record: Result = maybe_record.try_into(); + let record = maybe_record.unwrap(); + recovered.push(record); + } + assert_eq!(entries.as_slice(), recovered.as_slice()); + */ +} + +fn wincode_roundtrip(_slot_num: u64) { + /* + #[derive(Debug, Clone, PartialEq, Eq, SchemaWrite, SchemaRead)] + struct Finalize { + pubkey: [u8; 32], + bls_signature: [u8; 96], + } + + #[derive(Debug, Clone, PartialEq, Eq, SchemaWrite, SchemaRead)] + struct NotarizeCert { + bls_signature: [u8; 96], + #[wincode(with = "containers::Vec, ShortU16Len>")] + bitmap: Vec, + } + #[derive(Clone, Debug, Eq, PartialEq, SchemaWrite, SchemaRead)] + pub(crate) struct TlvRecord { + // type + pub(crate) typ: u8, + // length and serialized bytes of the value + #[wincode(with = "containers::Vec, ShortU16Len>")] + pub(crate) bytes: Vec, + } + + #[derive(Debug, Eq, PartialEq)] + enum AlpenglowVotor { + Slot(u64), + Finalize(Finalize), + NotarizeCert(NotarizeCert), + } + let notar_cert = NotarizeCert { + bitmap: vec![42u8; 2000 / 8], + bls_signature: [7; 96], + }; + let final_vote = Finalize { + pubkey: [3; 32], + bls_signature: [7; 96], + }; + let entries = [ + AlpenglowVotor::Slot(slot_num), + AlpenglowVotor::Finalize(final_vote), + AlpenglowVotor::NotarizeCert(notar_cert), + ]; + let mut buffer 
= BytesMut::with_capacity(1200); + + for e in entries.iter() { + let (typ, val) = match e { + AlpenglowVotor::Slot(slot) => (1, wincode::serialize(&slot)), + AlpenglowVotor::Finalize(finalize) => (10, wincode::serialize(&finalize)), + AlpenglowVotor::NotarizeCert(notiarize_cert) => { + (11, wincode::serialize(¬iarize_cert)) + } + }; + let val = val.unwrap(); + let tlv = TlvRecord { typ, bytes: val }; + + let len = wincode::serialize_into(&tlv, buffer.spare_capacity_mut()).unwrap(); + unsafe { + buffer.set_len(buffer.len() + len); + } + } + + let mut buffer = io::Reader::new(&buffer); + + let mut recovered = vec![]; + loop { + let mut tlv_record: MaybeUninit = MaybeUninit::uninit(); + + let read_res = TlvRecord::read(&mut buffer, &mut tlv_record); + if read_res.is_err() { + break; + } else { + let tlv_record = unsafe { tlv_record.assume_init() }; + let payload = match tlv_record.typ { + 1 => AlpenglowVotor::Slot(wincode::deserialize(&tlv_record.bytes).unwrap()), + 10 => AlpenglowVotor::Finalize(wincode::deserialize(&tlv_record.bytes).unwrap()), + 11 => { + AlpenglowVotor::NotarizeCert(wincode::deserialize(&tlv_record.bytes).unwrap()) + } + _ => panic!(), + }; + recovered.push(payload); + } + } + assert_eq!(&recovered, &entries); + */ +} +fn tlv(bench: &mut Bencher) { + let mut counter = 0; + bench.iter(|| { + tlv_roundtrip(counter); + counter += 1; + }) +} + +fn wincode(bench: &mut Bencher) { + let mut counter = 0; + bench.iter(|| { + wincode_roundtrip(counter); + counter += 1; + }); +} + +benchmark_group!(benches, tlv, wincode); +benchmark_main!(benches); diff --git a/tlv/src/lib.rs b/tlv/src/lib.rs new file mode 100644 index 00000000000000..2d79aff384ef8a --- /dev/null +++ b/tlv/src/lib.rs @@ -0,0 +1,213 @@ +use crate::lv_record::LvPayload; + +use { + bytes::{Bytes, BytesMut}, + wincode::{SchemaRead, SchemaWrite}, +}; + +mod lv_record; +mod short_u16; +pub use lv_record::{TlvDecodeError, TlvEncodeError, TlvRecord}; +use wincode::io::Cursor; + +pub use 
lv_record::LvRecord; +pub use lv_record::TlvSchemaReader; + +/// Walks TLV records in a buffer without parsing them. +/// +/// This iterator is not guaranteed to walk the entirety of the provided +/// buffer, and will instead stop on the first invalid entry. +pub struct TlvIter +where + T: AsRef<[u8]>, +{ + entries: Cursor, +} + +impl TlvIter +where + T: AsRef<[u8]>, +{ + /// Construct an iterator over Bytes object holding serialized TlvRecords + /// + /// This iterator is not guaranteed to walk the entirety of the provided + /// buffer, and will instead stop on the first invalid entry. + pub fn new(entries: T) -> Self { + Self { + entries: Cursor::new(entries), + } + } +} + +impl Iterator for TlvIter +where + T: AsRef<[u8]>, +{ + type Item = LvRecord; + /// Consume next item from the iterator. + /// If this returns None, no more valid items can be read from the buffer. + fn next(&mut self) -> Option { + let res: Result, _> = TlvSchemaReader::get(&mut self.entries); + res.ok().flatten() + } +} + +/// Serialize all entries into given buffer. +/// Buffer must have preallocated memory. +pub fn serialize_into_buffer<'a, T: 'a + SchemaWrite, B>( + entries: &'a [T], + buffer: &mut Cursor, +) -> Result<(), TlvEncodeError> { + for entry in entries { + wincode::serialize_into(buffer, &entry)?; + } + Ok(()) +} + +/// Walk over a given buffer returning deserialized TLV items +/// This will quietly skip all invalid entries. 
+pub fn deserialize_from_buffer>( + buffer: Bytes, +) -> impl Iterator { + TlvIter::new(buffer).filter_map(|v| v.try_into().ok()) +} + +#[cfg(test)] +mod tests { + use { + crate::{deserialize_from_buffer, serialize_into_buffer, LvRecord}, + bytes::BytesMut, + serde::Serialize, + solana_short_vec::decode_shortu16_len, + wincode::io::Cursor, + }; + #[derive(Debug, wincode::SchemaRead, wincode::SchemaWrite, PartialEq, Eq)] + enum ExtensionNew { + #[wincode(tag = 1)] + Test(LvRecord), + #[wincode(tag = 2)] + LegacyString(LvRecord), + #[wincode(tag = 3)] + NewString(LvRecord), + #[wincode(tag = 6)] + NewEmptyTag(()), + } + + #[derive(Debug, wincode::SchemaRead, wincode::SchemaWrite, PartialEq, Eq)] + enum ExtensionLegacy { + #[wincode(tag = 1)] + Test(LvRecord), + #[wincode(tag = 2)] + LegacyString(LvRecord), + } + + /// Test that TLV encoded data is backwards-compatible, + /// i.e. that new TLV data can be decoded by a new + /// receiver where possible, and skipped otherwise + #[test] + fn test_tlv_backwards_compat() { + let new_tlv_data = [ + ExtensionNew::Test(LvRecord::new(42)), + ExtensionNew::NewString(LvRecord::new(String::from("bla"))), + ExtensionNew::NewEmptyTag(()), + ]; + let mut buffer = BytesMut::with_capacity(2000); + serialize_into_buffer(&new_tlv_data, &mut buffer).unwrap(); + + let buffer = buffer.freeze(); + // check that both TLV are encoded correctly + let new_recovered: Vec = deserialize_from_buffer(buffer.clone()).collect(); + assert_eq!(new_recovered[0], ExtensionNew::Test(LvRecord::new(42))); + if let ExtensionNew::NewString(s) = &new_recovered[1] { + assert_eq!(s.payload(), "bla"); + } else { + panic!("Wrong deserialization") + }; + + // Make sure legacy recover works correctly + let legacy_recovered: Vec = + deserialize_from_buffer(buffer.clone()).collect(); + assert_eq!( + legacy_recovered[0], + ExtensionLegacy::Test(LvRecord::new(42)) + ); + assert_eq!( + legacy_recovered.len(), + 1, + "Legacy parser should only recover 1 entry" + ) + } + + 
/// Test that TLV encoded data is forwards-compatible,
/// i.e. that legacy TLV data can be decoded by a new
/// receiver
#[test]
fn test_tlv_forward_compat() {
    let legacy_tlv_data = [
        ExtensionLegacy::Test(LvRecord::new(42)),
        ExtensionLegacy::LegacyString(LvRecord::new(String::from("foo"))),
    ];
    let mut buffer: Vec<u8> = Vec::with_capacity(2000);
    // Lend the vector to the cursor instead of moving it in, so the
    // serialized bytes are still reachable once writing is done (same
    // pattern as `test_abi_wire_compat_gossip`).
    serialize_into_buffer(&legacy_tlv_data, &mut Cursor::new(&mut buffer)).unwrap();

    // Parse the same bytes using new parser. `deserialize_from_buffer`
    // takes `Bytes`, so hand the vector over via `Into`.
    let new_recovered: Vec<ExtensionNew> = deserialize_from_buffer(buffer.into()).collect();
    assert_eq!(new_recovered[0], ExtensionNew::Test(LvRecord::new(42)));
    if let ExtensionNew::LegacyString(s) = &new_recovered[1] {
        assert_eq!(s.payload(), "foo");
    } else {
        panic!("Wrong deserialization")
    };
}
{ + use solana_short_vec as short_vec; + #[derive(Serialize)] + struct TlvRecord { + typ: u8, // type + #[serde(with = "short_vec")] + bytes: Vec, // length and value + } + let tlv_data = vec![ExtensionNew::Test(LvRecord::new(u64::MAX))]; + let mut buffer: Vec = Vec::with_capacity(50); + serialize_into_buffer(&tlv_data, &mut Cursor::new(&mut buffer)).unwrap(); + let rec = TlvRecord { + typ: 1, + bytes: vec![255, 255, 255, 255, 255, 255, 255, 255], + }; + let bincode_vec = bincode::serialize(&rec).unwrap(); + assert_eq!(&bincode_vec, &buffer); + } +} diff --git a/tlv/src/lv_record.rs b/tlv/src/lv_record.rs new file mode 100644 index 00000000000000..96739686b865b4 --- /dev/null +++ b/tlv/src/lv_record.rs @@ -0,0 +1,175 @@ +use { + crate::{short_u16::ShortU16, TlvSerialize, MAX_VALUE_LENGTH}, + bytes::BytesMut, + std::marker::PhantomData, + wincode::{ + containers, io::Reader, len::ShortU16Len, ReadResult, SchemaRead, SchemaWrite, WriteError, + WriteResult, + }, +}; + +/// Marker trait for payloads, eligible for use in [struct LvRecord]. +/// Automatically implemented for all eligible types. +pub trait LvPayload: SchemaWrite + for<'a> SchemaRead<'a, Dst = Self> {} +impl LvPayload for T +where + T: SchemaWrite, + for<'a> T: SchemaRead<'a, Dst = T>, +{ +} + +/// A Length-Value encoded entry. Can be skipped by receivers that do not +/// have a parser for a specifc T. Used to construct entries in TLV enums: +/// ```rust +/// #[derive(Debug, wincode::SchemaRead, wincode::SchemaWrite)] +/// enum MyEnum { +/// #[wincode(tag = 1)] +/// A (LvRecord ), +/// #[wincode(tag = 2)] +/// A (LvRecord ), +/// } +/// ``` +#[derive(Clone, Debug, Eq, PartialEq, SchemaRead, SchemaWrite)] +pub struct LvRecord { + payload_len: ShortU16, + // direct access to payload is blocked to ensure payload_len is always in sync + payload: T, +} + +const MAX_VALUE_LENGTH: usize = 1280; + +impl LvRecord { + pub fn try_new(value: T) -> WriteResult { + let length = wincode::serialized_size(&value)? 
+ .try_into() + .map_err(|e| WriteError::LengthEncodingOverflow("Value too long"))?; + if length > MAX_VALUE_LENGTH { + return Err(WriteError::LengthEncodingOverflow("Value too long")); + } + Ok(Self { + payload_len: ShortU16(length), + payload: value, + }) + } + + /// simpler constructor for unittests + #[cfg(test)] + pub fn new(value: T) -> Self { + Self::try_new(value).expect("Construction should succeed") + } + + pub fn payload(&self) -> &T { + return self.payload; + } +} + +/// A high level interface to perform TLV parsing via wincode. +/// +/// ```rust +/// #[derive(Debug, wincode::SchemaRead, wincode::SchemaWrite)] +/// enum MyEnum { +/// #[wincode(tag = 3)] +/// A (LvRecord ), +/// } +/// let message = MyEnum::A ( +/// LvRecord::try_new(42).unwrap(), +/// ); +/// let buf = wincode::serialize(&message).unwrap(); +/// let res: Result, _> = TlvSchemaReader::get(&mut buf); +/// ``` +#[derive(Debug)] +pub struct TlvSchemaReader(PhantomData); + +impl<'de, T: SchemaRead<'de, Dst = T>> SchemaRead<'de> for TlvSchemaReader { + type Dst = Option; + + fn read( + reader: &mut impl Reader<'de>, + dst: &mut std::mem::MaybeUninit, + ) -> wincode::ReadResult<()> { + match T::get(reader) { + Ok(x) => { + dst.write(Some(x)); + Ok(()) + } + Err(wincode::ReadError::InvalidTagEncoding(_)) => { + let length = ShortU16::get(reader)?; + if length > MAX_VALUE_LENGTH { + return Err(wincode::ReadError::LengthEncodingOverflow("Value too long")); + } + reader.consume(length.into())?; + dst.write(None); + Ok(()) + } + Err(e) => Err(e), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum TlvDecodeError { + #[error("Invalid tag: {0}")] + InvalidTag(u8), + #[error("Malformed payload: {0}")] + MalformedPayload(#[from] wincode::ReadError), +} + +#[derive(Debug, thiserror::Error)] +pub enum TlvEncodeError { + #[error("Invalid tag: {0}")] + InvalidTag(u8), + #[error("Not enough space for payload in the buffer")] + NotEnoughSpace, + #[error("Payload exceeds MAX_VALUE_LENGTH")] + 
/// Newtype around `u16` used for length prefixes in this crate.
///
/// `#[repr(transparent)]` keeps the in-memory layout identical to a bare
/// `u16`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[repr(transparent)]
pub struct ShortU16(pub u16);

impl From<ShortU16> for usize {
    /// Widens the wrapped `u16` to `usize`.
    fn from(value: ShortU16) -> Self {
        let ShortU16(raw) = value;
        usize::from(raw)
    }
}
+#[inline(always)] +#[allow(clippy::arithmetic_side_effects)] +fn short_u16_bytes_needed(len: u16) -> usize { + 1 + (len >= 0x80) as usize + (len >= 0x4000) as usize +} + +#[inline(always)] +fn try_short_u16_bytes_needed>(len: T) -> WriteResult { + match len.try_into() { + Ok(len) => Ok(short_u16_bytes_needed(len)), + Err(_) => Err(write_length_encoding_overflow("u16::MAX")), + } +} +/// Encode a short u16 into the given buffer. +/// +/// See [`solana_short_vec::ShortU16`] for more details. +/// +/// # Safety +/// +/// - `dst` must be a valid for writes. +/// - `dst` must be valid for `needed` bytes. +#[inline(always)] +unsafe fn encode_short_u16(dst: *mut u8, needed: usize, len: u16) { + // From `solana_short_vec`: + // + // u16 serialized with 1 to 3 bytes. If the value is above + // 0x7f, the top bit is set and the remaining value is stored in the next + // bytes. Each byte follows the same pattern until the 3rd byte. The 3rd + // byte may only have the 2 least-significant bits set, otherwise the encoded + // value will overflow the u16. + unsafe { + match needed { + 1 => ptr::write(dst, len as u8), + 2 => { + ptr::write(dst, ((len & 0x7f) as u8) | 0x80); + ptr::write(dst.add(1), (len >> 7) as u8); + } + 3 => { + ptr::write(dst, ((len & 0x7f) as u8) | 0x80); + ptr::write(dst.add(1), (((len >> 7) & 0x7f) as u8) | 0x80); + ptr::write(dst.add(2), (len >> 14) as u8); + } + _ => unreachable!(), + } + } +} + +impl<'de> SchemaRead<'de> for ShortU16 { + type Dst = ShortU16; + + fn read( + reader: &mut impl Reader<'de>, + dst: &mut std::mem::MaybeUninit, + ) -> ReadResult<()> { + let Ok((len, read)) = decode_shortu16_len(reader.fill_buf(3)?) 
else { + return Err(read_length_encoding_overflow("u16::MAX")); + }; + unsafe { reader.consume_unchecked(read) }; + Ok(()) + } +} + +impl SchemaWrite for ShortU16 { + type Src = ShortU16; + + fn size_of(src: &Self::Src) -> wincode::WriteResult { + try_short_u16_bytes_needed(src.0) + } + + fn write(writer: &mut impl wincode::io::Writer, src: &Self::Src) -> wincode::WriteResult<()> { + let src = src.0; + let needed = short_u16_bytes_needed(src); + let mut buf = [MaybeUninit::::uninit(); 3]; + // SAFETY: short_u16 uses a maximum of 3 bytes, so the buffer is always large enough. + unsafe { encode_short_u16(buf.as_mut_ptr().cast(), needed, src) }; + // SAFETY: encode_short_u16 writes exactly `needed` bytes. + let buf = unsafe { transmute::<&[MaybeUninit], &[u8]>(buf.get_unchecked(..needed)) }; + writer.write(buf)?; + Ok(()) + } +} diff --git a/tlv/tlv-mac/Cargo.toml b/tlv/tlv-mac/Cargo.toml new file mode 100644 index 00000000000000..42b5ba2db1b94c --- /dev/null +++ b/tlv/tlv-mac/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "solana-tlv-mac" +edition = "2024" +description = "Solana TLV message authentication code (MAC) signature support" +documentation = "https://docs.rs/solana-tlv-mac" +version = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +publish = false + +[features] +agave-unstable-api = [] + +[dependencies] +chacha20 = "0.9.1" +chacha20poly1305 = { version = "0.10.1" } +poly1305 = "0.8.0" +wincode = { workspace = true, features = ["derive"] } + +[dev-dependencies] +bytes = { workspace = true } +solana-short-vec = { workspace = true } +solana-tlv = { workspace = true } + +[lints] +workspace = true diff --git a/tlv/tlv-mac/README.md b/tlv/tlv-mac/README.md new file mode 100644 index 00000000000000..a211f184d17faa --- /dev/null +++ b/tlv/tlv-mac/README.md @@ -0,0 +1,17 @@ +## Tag-Length-Value Signatures Solana + +Provides ability to sign the TLV-encoded 
packets with poly1305 symmetric key signatures. + +## Wire format + +The signature can include anywhere from 1 to 16 bytes, and is configurable +via const generic parameter. Default is 16 bytes. + +The signature is intended to operate on a UDP pseudoheader, as well as packet payload, +to ensure that injection attacks can not be executed. + +Example of the usage for alpenglow voting is available in `tests/vote.rs` + +## Performance + +This crate has not been heavily optimized and likely has room for further improvement diff --git a/tlv/tlv-mac/src/lib.rs b/tlv/tlv-mac/src/lib.rs new file mode 100644 index 00000000000000..8b4adf23e2b750 --- /dev/null +++ b/tlv/tlv-mac/src/lib.rs @@ -0,0 +1,161 @@ +use { + chacha20::{ + ChaCha20, + cipher::{KeyIvInit, StreamCipher}, + }, + chacha20poly1305::aead::KeyInit, + poly1305::{Poly1305, universal_hash::UniversalHash}, + std::{net::SocketAddr, ops::Deref}, + wincode::{SchemaRead, SchemaWrite}, +}; + +/// marker trait for valid signature sizes +pub trait _ValidSigSize {} //TODO: turn this into a const generic assert once it stabilizes +// can be anything from 1 to 16 bytes in practice, feel free to expand +impl _ValidSigSize for _ConstUsize<2> {} +impl _ValidSigSize for _ConstUsize<4> {} +impl _ValidSigSize for _ConstUsize<8> {} +impl _ValidSigSize for _ConstUsize<16> {} + +pub struct _ConstUsize; +impl _ConstUsize {} + +/// Poly 1305 signature for a packet. +/// Can be truncated from default of 16 bytes down as needed. +#[derive(SchemaWrite, SchemaRead, Debug, PartialEq, Eq, Clone, Copy)] +#[repr(transparent)] +pub struct Signature { + signature: [u8; N], +} + +impl Signature +where + _ConstUsize: _ValidSigSize, +{ + /// Fill UDP pseudoheader based on packet contents. + /// Panics if given incompatible address family addresses. 
+ fn fill_udp_pseudoheader( + src: SocketAddr, + dst: SocketAddr, + pseudoheader: &mut [u8; 36], + ) -> &[u8] { + match (src, dst) { + (SocketAddr::V4(src), SocketAddr::V4(dst)) => { + pseudoheader[0..4].copy_from_slice(&src.ip().octets()); + pseudoheader[4..6].copy_from_slice(&src.port().to_be_bytes()); + pseudoheader[6..10].copy_from_slice(&dst.ip().octets()); + pseudoheader[10..12].copy_from_slice(&dst.port().to_be_bytes()); + &pseudoheader[0..12] + } + (SocketAddr::V6(src), SocketAddr::V6(dst)) => { + pseudoheader[0..16].copy_from_slice(&src.ip().octets()); + pseudoheader[16..18].copy_from_slice(&src.port().to_be_bytes()); + pseudoheader[18..34].copy_from_slice(&dst.ip().octets()); + pseudoheader[34..36].copy_from_slice(&dst.port().to_be_bytes()); + pseudoheader.as_slice() + } + _ => { + debug_assert!(false, "Can not mix v4 and v6 addresses in one packet"); + &pseudoheader[0..0] + } + } + } + + /// Compute a Poly1305 MAC with a one-time key derived from ChaCha20. + /// This will truncate the 16-byte MAC to the required length + /// - src: source address + /// - dst: destination address + /// - key: 32-byte pre-shared key + /// - nonce: 12-byte IETF nonce (unique per message under a given key) + /// - msg: message bytes + pub fn new_poly1305_for_udp( + src: SocketAddr, + dst: SocketAddr, + key: &[u8; 32], + nonce: &[u8; 12], + msg: &[u8], + ) -> Self { + let mut pseudoheader = [0u8; 36]; // fit both v6 and v4 + let pseudoheader = Self::fill_udp_pseudoheader(src, dst, &mut pseudoheader); + Self { + signature: *chacha20_poly1305_mac(key, nonce, &[pseudoheader, msg]) + .first_chunk() + .expect("This operation is infallible for any N<=16 as enforced by type system"), + } + } +} + +impl Deref for Signature +where + _ConstUsize: _ValidSigSize, +{ + type Target = [u8; N]; + + fn deref(&self) -> &Self::Target { + &self.signature + } +} + +/// Compute a Poly1305 MAC with a one-time key derived from ChaCha20. 
+/// - key: 32-byte pre-shared key +/// - nonce: 12-byte IETF nonce (unique per message under a given key) +/// - msg: message to authenticate +fn chacha20_poly1305_mac(key: &[u8; 32], nonce: &[u8; 12], blocks: &[&[u8]]) -> [u8; 16] { + // 1) Derive one-time Poly1305 key = ChaCha20(key, nonce, counter=0) + let mut cipher: ChaCha20 = KeyIvInit::new(key.into(), nonce.into()); + let mut otk = [0u8; 32]; + cipher.apply_keystream(&mut otk); // keystream over zeros to get OTK + + // 2) Compute Poly1305 tag over the message using the one-time key + let mut poly = Poly1305::new((&otk).into()); + for &block in blocks.iter() { + poly.update_padded(block); + } + let tag = poly.finalize(); + + let mut out = [0u8; 16]; + out.copy_from_slice(tag.as_slice()); + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mac_roundtrip() { + let mut key = [0x11u8; 32]; + let nonce = [0x22u8; 12]; + let mut msg = [1u8; 42]; + let tag = chacha20_poly1305_mac(&key, &nonce, &[&msg]); + let tag2 = chacha20_poly1305_mac(&key, &nonce, &[&msg[..32], &msg[32..]]); + assert_eq!(tag, tag2); + // negative tests + msg[2] ^= 1; + let bad = chacha20_poly1305_mac(&key, &nonce, &[&msg]); + assert_ne!(bad, tag); + msg[2] ^= 1; + key[2] ^= 1; + let bad = chacha20_poly1305_mac(&key, &nonce, &[&msg]); + assert_ne!(bad, tag); + } + #[test] + fn mac_udp_roundtrip() { + let key = &[0x11u8; 32]; + let nonce = &[0x22u8; 12]; + let msg = &mut [1u8; 42]; + let src1 = "1.2.3.4:8833".parse().unwrap(); + let src2 = "1.2.3.4:8822".parse().unwrap(); + let dst1 = "1.2.3.4:8833".parse().unwrap(); + let signature1: Signature<8> = Signature::new_poly1305_for_udp(src1, dst1, key, nonce, msg); + let signature2 = Signature::new_poly1305_for_udp(src1, dst1, key, nonce, msg); + assert_eq!(signature1, signature2); + + // negative tests + let signature2 = Signature::new_poly1305_for_udp(src2, dst1, key, nonce, msg); + assert_ne!(signature1, signature2); + msg[2] ^= 1; + let signature2 = 
Signature::new_poly1305_for_udp(src1, dst1, key, nonce, msg); + assert_ne!(signature1, signature2); + } +} diff --git a/tlv/tlv-mac/tests/vote.rs b/tlv/tlv-mac/tests/vote.rs new file mode 100644 index 00000000000000..4ad0b93141b7a8 --- /dev/null +++ b/tlv/tlv-mac/tests/vote.rs @@ -0,0 +1,119 @@ +#![allow(clippy::arithmetic_side_effects)] +use { + bytes::{Bytes, BytesMut}, + solana_tlv::*, + solana_tlv_mac::Signature, + std::net::SocketAddr, + wincode::{SchemaRead, SchemaWrite, containers, len::ShortU16Len}, +}; + +#[derive(Debug, Clone, SchemaRead, SchemaWrite, PartialEq, Eq)] +struct Finalize { + pubkey: [u8; 32], + bls_signature: [u8; 96], +} + +#[derive(Debug, Clone, SchemaRead, SchemaWrite, PartialEq, Eq)] +struct NotarizeCert { + bls_signature: [u8; 96], + #[wincode(with = "containers::Vec")] + bitmap: Vec, +} + +define_tlv_enum! (pub(crate) enum AlepnglowVotor { + 1=>Nonce(u64), + 2=>Mac(Signature<16>), + 10=>Finalize(Finalize), + 11=>NotarizeCert(NotarizeCert), +}); + +fn main() { + let notar_cert = NotarizeCert { + bitmap: vec![42u8; 2000 / 8], + bls_signature: [7; 96], + }; + let final_vote = Finalize { + pubkey: [3; 32], + bls_signature: [7; 96], + }; + let ag_nonce = 1231244; + + // allocate space for a packet and fill it with data + let mut buffer = BytesMut::with_capacity(1200); + let entries = [ + AlepnglowVotor::Nonce(ag_nonce), + AlepnglowVotor::Finalize(final_vote), + AlepnglowVotor::NotarizeCert(notar_cert), + ]; + serialize_into_buffer(&entries, &mut buffer).unwrap(); + + // sign packet + let src: SocketAddr = "1.2.3.4:8888".parse().unwrap(); + let dst: SocketAddr = "5.6.7.8:8888".parse().unwrap(); + let key = [1; 32]; + let mut nonce = [0u8; 12]; + nonce[0..8].copy_from_slice(&ag_nonce.to_be_bytes()); + let signature: Signature<16> = + Signature::new_poly1305_for_udp(src, dst, &key, &nonce, &buffer[..buffer.len()]); + // write signature into the packet + serialize_into_buffer(&[AlepnglowVotor::Mac(signature)], &mut buffer).unwrap(); + let buffer 
= buffer.freeze(); + + let entries_rx = decode_and_verify_signature(src, dst, key, buffer.clone()).unwrap(); + assert_eq!( + entries_rx[0..=2], + entries, + "The original entries should match up" + ); + + // try replaying a valid message to a wrong address + let dst2: SocketAddr = "1.2.5.4:8888".parse().unwrap(); + + assert_eq!( + decode_and_verify_signature(src, dst2, key, buffer).unwrap_err(), + "Invalid packet!" + ); +} + +fn decode_and_verify_signature( + src: SocketAddr, + dst: SocketAddr, + key: [u8; 32], + buffer: Bytes, +) -> Result, String> { + // decode and verify the signature + let mut recovered = Vec::new(); + let mut signed_portion = 0; + let mut nonce = [0u8; 12]; + for maybe_record in TlvIter::new(buffer.clone()) { + let size = maybe_record.serialized_size(); + let maybe_record: Result = maybe_record.try_into(); + let record = match maybe_record { + Ok(record) => record, + Err(e) => Err(e.to_string())?, + }; + + match record { + AlepnglowVotor::Mac(signature) => { + let correct_signature: Signature<16> = Signature::new_poly1305_for_udp( + src, + dst, + &key, + &nonce, + &buffer[..signed_portion], + ); + if signature != correct_signature { + return Err("Invalid packet!".to_owned()); + } + recovered.push(record); + break; // do not read past signed portion! + } + AlepnglowVotor::Nonce(rx_nonce) => { + nonce[0..8].copy_from_slice(&rx_nonce.to_be_bytes()); + } + _ => recovered.push(record), + } + signed_portion += size; + } + Ok(recovered) +}