diff --git a/ipa-core/src/ff/mod.rs b/ipa-core/src/ff/mod.rs index 590b7e88b..a5e92ccac 100644 --- a/ipa-core/src/ff/mod.rs +++ b/ipa-core/src/ff/mod.rs @@ -90,6 +90,46 @@ pub trait Serializable: Sized { .map_err(Into::into) .unwrap_infallible() } + + /// This method provides the same functionality as [`Self::deserialize`] without + /// compile-time guarantees on the size of `buf`. Therefore, it is not appropriate + /// to use in production code. It is provided as convenience method + /// for tests that are ok to panic. + /// + /// ## Panics + /// If the size of `buf` is not equal to `Self::Size` or if `buf` bytes + /// are not a valid representation of this instance. See [`Self::deserialize`] for + /// more details. + /// + /// [`Self::deserialize`]: Self::deserialize + #[cfg(test)] + #[must_use] + fn deserialize_from_slice(buf: &[u8]) -> Self { + use typenum::Unsigned; + + assert_eq!(buf.len(), Self::Size::USIZE); + + let mut arr = GenericArray::default(); + arr.copy_from_slice(buf); + Self::deserialize(&arr).unwrap() + } + + /// This method provides the same functionality as [`Self::serialize`] without + /// compile-time guarantees on the size of `buf`. Therefore, it is not appropriate + /// to use in production code. It is provided as convenience method + /// for tests that are ok to panic. + /// + /// ## Panics + /// If the size of `buf` is not equal to `Self::Size`. + #[cfg(test)] + fn serialize_to_slice(&self, buf: &mut [u8]) { + use typenum::Unsigned; + + assert_eq!(buf.len(), Self::Size::USIZE); + + let dest = GenericArray::<_, Self::Size>::from_mut_slice(buf); + self.serialize(dest); + } } pub trait ArrayAccess { diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 9133247e3..819de8c4d 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -64,6 +64,10 @@ pub use gateway_exports::{Gateway, MpcReceivingEnd, SendingEnd, ShardReceivingEn pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; +#[cfg(feature = "in-memory-infra")] +pub use transport::{ + config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, +}; pub use transport::{ make_owned_handler, query, routing, ApiError, BodyStream, BytesStream, HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId, @@ -71,8 +75,6 @@ pub use transport::{ RouteParams, SingleRecordStream, StepBinding, StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, }; -#[cfg(feature = "in-memory-infra")] -pub use transport::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; use typenum::{Const, ToUInt, Unsigned, U8}; use x25519_dalek::PublicKey; @@ -130,6 +132,20 @@ impl TryFrom for HelperIdentity { } } +impl TryFrom<&str> for HelperIdentity { + type Error = String; + + fn try_from(value: &str) -> std::result::Result { + for identity in HelperIdentity::make_three() { + if identity.as_str() == value { + return Ok(identity); + } + } + + Err(format!("{value} is not a valid helper identity")) + } +} + impl From for u8 { fn from(value: HelperIdentity) -> Self { value.id @@ -138,16 +154,7 @@ impl From for u8 { impl Debug for HelperIdentity { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self.id { - 1 => "A", - 2 => "B", - 3 => "C", - _ => unreachable!(), - } - ) + write!(f, "{}", self.as_str()) } } diff --git a/ipa-core/src/helpers/transport/in_memory/config.rs b/ipa-core/src/helpers/transport/in_memory/config.rs new file mode 100644 index 000000000..95834c318 --- /dev/null +++ b/ipa-core/src/helpers/transport/in_memory/config.rs @@ -0,0 +1,144 @@ +use std::borrow::Cow; + +use crate::{ + helpers::{HelperIdentity, Role, RoleAssignment}, + protocol::Gate, + sharding::ShardIndex, + sync::Arc, +}; + +pub type DynStreamInterceptor = Arc>; + +/// The interface for stream interceptors. +/// +/// It is used in test infrastructure to inspect +/// incoming streams and perform actions based on +/// their contents. +/// +/// The `peek` method takes a context object and a mutable reference +/// to the data buffer. It is responsible for inspecting the data +/// and performing any necessary actions based on the context. +pub trait StreamInterceptor: Send + Sync { + /// The context type for the stream peeker. + /// See [`InspectContext`] and [`MaliciousHelperContext`] for + /// details. + type Context; + + /// Inspects the stream data and performs any necessary actions. + /// The `data` buffer may be modified in-place. + /// + /// ## Implementation considerations + /// This method is free to mutate the `data` buffer + /// however it wants, but it needs to account for the following: + /// + /// ### Prime field streams + /// Corrupting streams that send data as sequences of serialized + /// [`PrimeField`] may cause `GreaterThanPrimeError` errors at + /// the serialization layer, instead of maybe intended malicious + /// validation failures. + /// + /// ### Boolean fields + /// Flipping bits in fixed-size bit strings is indistinguishable + /// from additive attacks without additional measures implemented + /// at the transport layer, like checksumming, share consistency + /// checks, etc. + fn peek(&self, ctx: &Self::Context, data: &mut Vec); +} + +impl) + Send + Sync + 'static> StreamInterceptor for F { + type Context = InspectContext; + + fn peek(&self, ctx: &Self::Context, data: &mut Vec) { + (self)(ctx, data); + } +} + +/// The general context provided to stream inspectors. +#[derive(Debug)] +pub struct InspectContext { + /// The shard index of this instance. + /// This is `None` for non-sharded helpers. + pub shard_index: Option, + /// The MPC identity of this instance. + /// The combination (`shard_index`, `identity`) + /// uniquely identifies a single shard within + /// a multi-sharded MPC system. + pub identity: HelperIdentity, + /// Helper that will receive this stream. + pub dest: Cow<'static, str>, + /// Circuit gate this stream is tied to. + pub gate: Gate, +} + +/// The no-op stream peeker, which does nothing. +/// This is used as a default value for stream +/// peekers that don't do anything. +#[inline] +#[must_use] +pub fn passthrough() -> Arc> { + Arc::new(|_ctx: &InspectContext, _data: &mut Vec| {}) +} + +/// This narrows the implementation of stream seeker +/// to a specific helper role. Only streams sent from +/// that helper will be inspected by the provided closure. +/// Other helper's streams will be left untouched. +/// +/// It does not support sharded environments and will panic +/// if used in a sharded test infrastructure. +#[derive(Debug)] +pub struct MaliciousHelper { + identity: HelperIdentity, + role_assignment: RoleAssignment, + inner: F, +} + +impl) + Send + Sync> MaliciousHelper { + pub fn new(role: Role, role_assignment: &RoleAssignment, peeker: F) -> Arc { + Arc::new(Self { + identity: role_assignment.identity(role), + role_assignment: role_assignment.clone(), + inner: peeker, + }) + } + + fn context(&self, ctx: &InspectContext) -> MaliciousHelperContext { + let dest = HelperIdentity::try_from(ctx.dest.as_ref()) + .unwrap_or_else(|e| panic!("Can't resolve helper identity for {}: {e}", ctx.dest)); + let dest = self.role_assignment.role(dest); + + MaliciousHelperContext { + shard_index: ctx.shard_index, + dest, + gate: ctx.gate.clone(), + } + } +} + +/// Special contexts for stream inspectors +/// created with [`MaliciousHelper`]. +/// It provides convenient access to the +/// destination role and assumes a single MPC +/// helper intercepting streams. +#[derive(Debug)] +pub struct MaliciousHelperContext { + /// The shard index of this instance. + /// This is `None` for non-sharded helpers. + pub shard_index: Option, + /// Helper that will receive this stream. + pub dest: Role, + /// Circuit gate this stream is tied to. + pub gate: Gate, +} + +impl) + Send + Sync> StreamInterceptor + for MaliciousHelper +{ + type Context = InspectContext; + + fn peek(&self, ctx: &Self::Context, data: &mut Vec) { + if ctx.identity == self.identity { + (self.inner)(&self.context(ctx), data); + } + } +} diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index 2187c20bf..16617e514 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -1,13 +1,16 @@ +pub mod config; mod sharding; mod transport; -use std::array; - pub use sharding::InMemoryShardNetwork; pub use transport::Setup; +use transport::TransportConfigBuilder; use crate::{ - helpers::{HandlerRef, HelperIdentity}, + helpers::{ + in_memory_config::DynStreamInterceptor, transport::in_memory::config::passthrough, + HandlerRef, HelperIdentity, + }, sync::{Arc, Weak}, }; @@ -21,15 +24,32 @@ pub struct InMemoryMpcNetwork { impl Default for InMemoryMpcNetwork { fn default() -> Self { - Self::new(array::from_fn(|_| None)) + Self::new(Self::noop_handlers()) } } impl InMemoryMpcNetwork { + #[must_use] + pub fn noop_handlers() -> [Option; 3] { + [None, None, None] + } + #[must_use] pub fn new(handlers: [Option; 3]) -> Self { - let [mut first, mut second, mut third]: [_; 3] = - HelperIdentity::make_three().map(Setup::new); + Self::with_stream_interceptor(handlers, &passthrough()) + } + + #[must_use] + pub fn with_stream_interceptor( + handlers: [Option; 3], + interceptor: &DynStreamInterceptor, + ) -> Self { + let [mut first, mut second, mut third]: [_; 3] = HelperIdentity::make_three().map(|i| { + let mut config_builder = TransportConfigBuilder::for_helper(i); + config_builder.with_interceptor(interceptor); + + Setup::with_config(i, config_builder.not_sharded()) + }); first.connect(&mut second); second.connect(&mut third); diff --git a/ipa-core/src/helpers/transport/in_memory/sharding.rs b/ipa-core/src/helpers/transport/in_memory/sharding.rs index e76a7496a..e456dfa2a 100644 --- a/ipa-core/src/helpers/transport/in_memory/sharding.rs +++ b/ipa-core/src/helpers/transport/in_memory/sharding.rs @@ -1,6 +1,7 @@ use crate::{ helpers::{ - transport::in_memory::transport::{InMemoryTransport, Setup}, + in_memory_config::{passthrough, DynStreamInterceptor}, + transport::in_memory::transport::{InMemoryTransport, Setup, TransportConfigBuilder}, HelperIdentity, }, sharding::ShardIndex, @@ -22,9 +23,22 @@ pub struct InMemoryShardNetwork { impl InMemoryShardNetwork { pub fn with_shards>(shard_count: I) -> Self { + Self::with_stream_interceptor(shard_count, &passthrough()) + } + + pub fn with_stream_interceptor>( + shard_count: I, + interceptor: &DynStreamInterceptor, + ) -> Self { let shard_count = shard_count.into(); let shard_network: [_; 3] = HelperIdentity::make_three().map(|h| { - let mut shard_connections = shard_count.iter().map(Setup::new).collect::>(); + let mut config_builder = TransportConfigBuilder::for_helper(h); + config_builder.with_interceptor(interceptor); + + let mut shard_connections = shard_count + .iter() + .map(|i| Setup::with_config(i, config_builder.bind_to_shard(i))) + .collect::>(); for i in 0..shard_connections.len() { let (lhs, rhs) = shard_connections.split_at_mut(i); if let Some((a, _)) = lhs.split_last_mut() { diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index f4300b6b9..3c1a9e926 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -21,12 +21,18 @@ use tracing::Instrument; use crate::{ error::BoxError, helpers::{ - transport::routing::{Addr, RouteId}, - ApiError, BodyStream, HandlerRef, HelperResponse, NoResourceIdentifier, QueryIdBinding, - ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, - TransportIdentity, + in_memory_config, + in_memory_config::DynStreamInterceptor, + transport::{ + in_memory::config::InspectContext, + routing::{Addr, RouteId}, + }, + ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoResourceIdentifier, + QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, + Transport, TransportIdentity, }, protocol::{Gate, QueryId}, + sharding::ShardIndex, sync::{Arc, Weak}, }; @@ -66,15 +72,21 @@ pub struct InMemoryTransport { identity: I, connections: HashMap>, record_streams: StreamCollection, + config: TransportConfig, } impl InMemoryTransport { #[must_use] - fn new(identity: I, connections: HashMap>) -> Self { + fn with_config( + identity: I, + connections: HashMap>, + config: TransportConfig, + ) -> Self { Self { identity, connections, record_streams: StreamCollection::default(), + config, } } @@ -170,12 +182,27 @@ impl Transport for Weak> { let this = self.upgrade().unwrap(); let channel = this.get_channel(dest); let addr = Addr::from_route(Some(this.identity), route); + let gate = addr.gate.clone(); + let (ack_tx, ack_rx) = oneshot::channel(); + let context = gate.map(|gate| InspectContext { + shard_index: this.config.shard_index, + identity: this.config.identity, + dest: dest.as_str(), + gate, + }); channel .send(( addr, - InMemoryStream::wrap(data.map(Bytes::from).map(Ok)), + InMemoryStream::wrap(data.map({ + move |mut chunk| { + if let Some(ref context) = context { + this.config.stream_interceptor.peek(context, &mut chunk); + } + Ok(Bytes::from(chunk)) + } + })), ack_tx, )) .await @@ -251,17 +278,30 @@ pub struct Setup { tx: ConnectionTx, rx: ConnectionRx, connections: HashMap>, + config: TransportConfig, +} + +impl Setup { + #[must_use] + #[allow(unused)] + pub fn new(identity: HelperIdentity) -> Self { + Self::with_config( + identity, + TransportConfigBuilder::for_helper(identity).not_sharded(), + ) + } } impl Setup { #[must_use] - pub fn new(identity: I) -> Self { + pub fn with_config(identity: I, config: TransportConfig) -> Self { let (tx, rx) = channel(16); Self { identity, tx, rx, connections: HashMap::default(), + config, } } @@ -288,7 +328,11 @@ impl Setup { self, handler: Option>, ) -> (ConnectionTx, Arc>) { - let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections)); + let transport = Arc::new(InMemoryTransport::with_config( + self.identity, + self.connections, + self.config, + )); transport.listen(handler, self.rx); (self.tx, transport) @@ -579,3 +623,45 @@ mod tests { assert_eq!(vec![vec![0, 1]], recv.collect::>().await); } } + +pub struct TransportConfig { + pub shard_index: Option, + pub identity: HelperIdentity, + pub stream_interceptor: DynStreamInterceptor, +} + +pub struct TransportConfigBuilder { + identity: HelperIdentity, + stream_interceptor: DynStreamInterceptor, +} + +impl TransportConfigBuilder { + pub fn for_helper(identity: HelperIdentity) -> Self { + Self { + identity, + stream_interceptor: in_memory_config::passthrough(), + } + } + + pub fn with_interceptor(&mut self, interceptor: &DynStreamInterceptor) -> &mut Self { + self.stream_interceptor = Arc::clone(interceptor); + + self + } + + pub fn bind_to_shard(&self, shard_index: ShardIndex) -> TransportConfig { + TransportConfig { + shard_index: Some(shard_index), + identity: self.identity, + stream_interceptor: Arc::clone(&self.stream_interceptor), + } + } + + pub fn not_sharded(&self) -> TransportConfig { + TransportConfig { + shard_index: None, + identity: self.identity, + stream_interceptor: Arc::clone(&self.stream_interceptor), + } + } +} diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 37f45c449..c3bb307d8 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -24,7 +24,7 @@ pub use handler::{ make_owned_handler, Error as ApiError, HandlerBox, HandlerRef, HelperResponse, RequestHandler, }; #[cfg(feature = "in-memory-infra")] -pub use in_memory::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; +pub use in_memory::{config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; pub use receive::{LogErrors, ReceiveRecords}; #[cfg(feature = "web-app")] pub use stream::WrappedAxumBodyStream; @@ -53,7 +53,12 @@ impl Identity for ShardIndex { } impl Identity for HelperIdentity { fn as_str(&self) -> Cow<'static, str> { - Cow::Owned(self.id.to_string()) + Cow::Borrowed(match *self { + Self::ONE => "A", + Self::TWO => "B", + Self::THREE => "C", + _ => unreachable!(), + }) } } diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index 2ba465137..f8e3019f5 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -199,35 +199,21 @@ mod tests { } mod malicious { - use futures::future::try_join; + use rand::{distributions::Standard, prelude::Distribution}; use crate::{ error::Error, ff::{Field, Fp32BitPrime, Gf2, Gf32Bit}, - helpers::{Direction, Role}, + helpers::{in_memory_config::MaliciousHelper, Role}, protocol::{ - basics::{ - mul::step::MaliciousMultiplyStep::{RandomnessForValidation, ReshareRx}, - Reshare, - }, - context::{ - Context, SemiHonestContext, UpgradableContext, UpgradedContext, - UpgradedMaliciousContext, Validator, - }, - prss::SharedRandomness, + basics::Reshare, + context::{Context, UpgradableContext, UpgradedContext, Validator}, RecordId, }, rand::{thread_rng, Rng}, - secret_sharing::{ - replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, - ReplicatedSecretSharing, - }, - SharedValue, - }, - test_fixture::{Reconstruct, Runner, TestWorld}, + secret_sharing::{replicated::malicious::ExtendableField, SharedValue}, + test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; /// Relies on semi-honest protocol tests that enforce reshare to communicate and produce @@ -255,83 +241,6 @@ mod tests { } } - async fn reshare_with_additive_attack( - ctx: C, - input: &Replicated, - record_id: RecordId, - to_helper: Role, - additive_error: F, - ) -> Result, Error> { - let (r0, r1) = ctx.prss().generate_fields(record_id); - - // `to_helper.left` calculates part1 = (input.0 + input.1) - r1 and sends part1 to `to_helper.right` - // This is same as (a1 + a2) - r2 in the diagram - if ctx.role() == to_helper.peer(Direction::Left) { - let send_channel = ctx.send_channel(to_helper.peer(Direction::Right)); - let receive_channel = ctx.recv_channel(to_helper.peer(Direction::Right)); - - let part1 = input.left() + input.right() - r1 + additive_error; - send_channel.send(record_id, part1).await?; - - // Sleep until `to_helper.right` sends us their part2 value - let part2 = receive_channel.receive(record_id).await?; - - Ok(Replicated::new(part1 + part2, r1)) - } else if ctx.role() == to_helper.peer(Direction::Right) { - let send_channel = ctx.send_channel(to_helper.peer(Direction::Left)); - let receive_channel = ctx.recv_channel::(to_helper.peer(Direction::Left)); - - // `to_helper.right` calculates part2 = (input.left() - r0) and sends it to `to_helper.left` - // This is same as (a3 - r3) in the diagram - let part2 = input.left() - r0 + additive_error; - send_channel.send(record_id, part2).await?; - - // Sleep until `to_helper.left` sends us their part1 value - let part1 = receive_channel.receive(record_id).await?; - - Ok(Replicated::new(r0, part1 + part2)) - } else { - Ok(Replicated::new(r0, r1)) - } - } - - async fn reshare_malicious_with_additive_attack( - ctx: UpgradedMaliciousContext<'_, F>, - input: &MaliciousReplicated, - record_id: RecordId, - to_helper: Role, - small_field_additive_error: F, - large_field_additive_error: F::ExtendedField, - ) -> Result, Error> { - use crate::{ - protocol::context::SpecialAccessToUpgradedContext, - secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious, - }; - let random_constant_ctx = ctx.narrow(&RandomnessForValidation); - - let (rx, x) = try_join( - reshare_with_additive_attack( - SemiHonestContext::from_base(ctx.narrow(&ReshareRx).base_context()), - input.rx(), - record_id, - to_helper, - large_field_additive_error, - ), - reshare_with_additive_attack( - SemiHonestContext::from_base(ctx.base_context()), - input.x().access_without_downgrade(), - record_id, - to_helper, - small_field_additive_error, - ), - ) - .await?; - let malicious_input = MaliciousReplicated::new(x, rx); - - random_constant_ctx.accumulate_macs(record_id, &malicious_input); - Ok(malicious_input) - } - #[tokio::test] async fn fp32bit_reshare_validation_fail() { const PERTURBATIONS: [(Fp32BitPrime, Fp32BitPrime); 3] = [ @@ -357,37 +266,48 @@ mod tests { F: ExtendableField, Standard: Distribution, { - let world = TestWorld::default(); - let mut rng = thread_rng(); - - let a = rng.gen::(); + const STEP: &str = "malicious-attack"; + + /// Corrupts a single value `F` by running an additive attack. + /// `binary_data` must carry the exact one value of `F` and the result + /// will be written back to it, so the attack is run in place + fn corrupt(binary_data: &mut [u8], add: F) { + let v = F::deserialize_from_slice(binary_data) + add; + v.serialize_to_slice(binary_data); + } - let to_helper = Role::H1; + for (small_value, large_value) in perturbations.iter().copied() { + for malicious_actor in [Role::H2, Role::H3] { + let mut config = TestWorldConfig::default(); + config.stream_interceptor = MaliciousHelper::new( + malicious_actor, + config.role_assignment(), + move |ctx, data| { + if ctx.gate.as_ref().contains(STEP) { + if ctx.gate.as_ref().contains("reshare_rx") { + corrupt(data, large_value); + } else { + corrupt(data, small_value); + } + } + }, + ); + let world = TestWorld::new_with(&config); + let mut rng = thread_rng(); + let a = rng.gen::(); + let to_helper = Role::H1; - for perturbation in perturbations { - for malicious_actor in &[Role::H2, Role::H3] { world .malicious(a, |ctx, a| async move { let v = ctx.validator(); let m_ctx = v.context().set_total_records(1); - let record_id = RecordId::from(0); let m_a = v.context().upgrade(a).await.unwrap(); - let m_reshared_a = if m_ctx.role() == *malicious_actor { - // This role is spoiling the value. - reshare_malicious_with_additive_attack( - m_ctx, - &m_a, - record_id, - to_helper, - perturbation.0, - perturbation.1, - ) + let m_reshared_a = m_a + .reshare(m_ctx.narrow(STEP), RecordId::FIRST, to_helper) .await - .unwrap() - } else { - m_a.reshare(m_ctx, record_id, to_helper).await.unwrap() - }; + .unwrap(); + match v.validate(m_reshared_a).await { Ok(result) => panic!("Got a result {result:?}"), Err(err) => { diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index f0d98b018..1d59df603 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -206,32 +206,21 @@ where mod tests { use std::iter::zip; - use futures::future::{join_all, try_join, try_join3}; + use futures::future::join_all; use crate::{ error::Error, - ff::{Field, Fp31, Fp32BitPrime}, - helpers::{Direction, Role}, + ff::{Field, Fp31, Fp32BitPrime, Serializable}, + helpers::{in_memory_config::MaliciousHelper, Role}, protocol::{ basics::Reveal, - context::{ - Context, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, Validator, - }, + context::{Context, UpgradableContext, UpgradedContext, Validator}, RecordId, }, rand::{thread_rng, Rng}, - secret_sharing::{ - replicated::{ - malicious::{ - AdditiveShare as MaliciousReplicated, ExtendableField, - ThisCodeIsAuthorizedToDowngradeFromMalicious, - }, - semi_honest::AdditiveShare, - }, - IntoShares, SharedValue, - }, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares, SharedValue}, test_executor::run, - test_fixture::{join3v, Runner, TestWorld}, + test_fixture::{join3v, Runner, TestWorld, TestWorldConfig}, }; #[tokio::test] @@ -395,100 +384,62 @@ mod tests { } #[test] - pub fn malicious_validation_fail() { - run(|| async { - let mut rng = thread_rng(); - let world = TestWorld::default(); - let sh_ctx = world.malicious_contexts(); - let v = sh_ctx.map(UpgradableContext::validator); - let m_ctx = v.each_ref().map(|v| v.context().set_total_records(1)); - - let record_id = RecordId::from(0); - let input: Fp31 = rng.gen(); - - let m_shares = join3v( - zip(m_ctx.iter(), input.share_with(&mut rng)) - .map(|(m_ctx, share)| async { m_ctx.upgrade(share).await }), - ) - .await; - let result = try_join3( - m_shares[0].reveal(m_ctx[0].clone(), record_id), - m_shares[1].reveal(m_ctx[1].clone(), record_id), - reveal_with_additive_attack( - m_ctx[2].clone(), - record_id, - &m_shares[2], - false, - Fp31::ONE, - ), - ) - .await; - - assert!(matches!(result, Err(Error::MaliciousRevealFailed))); - }); + pub fn malicious_generic_validation_fail() { + let partial = false; + malicious_validation_fail(partial); } #[test] pub fn malicious_partial_validation_fail() { - run(|| async { - let mut rng = thread_rng(); - let world = TestWorld::default(); - let sh_ctx = world.malicious_contexts(); - let v = sh_ctx.map(UpgradableContext::validator); - let m_ctx: [_; 3] = v.each_ref().map(|v| v.context().set_total_records(1)); - - let record_id = RecordId::from(0); - let input: Fp31 = rng.gen(); - - let m_shares = join3v( - zip(m_ctx.iter(), input.share_with(&mut rng)) - .map(|(m_ctx, share)| async { m_ctx.upgrade(share).await }), - ) - .await; - let result = try_join3( - m_shares[0].partial_reveal(m_ctx[0].clone(), record_id, Role::H3), - m_shares[1].partial_reveal(m_ctx[1].clone(), record_id, Role::H3), - reveal_with_additive_attack( - m_ctx[2].clone(), - record_id, - &m_shares[2], - true, - Fp31::ONE, - ), - ) - .await; - - assert!(matches!(result, Err(Error::MaliciousRevealFailed))); - }); + let partial = true; + malicious_validation_fail(partial); } - pub async fn reveal_with_additive_attack( - ctx: UpgradedMaliciousContext<'_, F>, - record_id: RecordId, - input: &MaliciousReplicated, - excluded: bool, - additive_error: F, - ) -> Result, Error> { - let (left, right) = input.x().access_without_downgrade().as_tuple(); - let left_sender = ctx.send_channel(ctx.role().peer(Direction::Left)); - let right_sender = ctx.send_channel(ctx.role().peer(Direction::Right)); - let left_recv = ctx.recv_channel(ctx.role().peer(Direction::Left)); - let right_recv = ctx.recv_channel(ctx.role().peer(Direction::Right)); - - // Send share to helpers to the right and left - try_join( - left_sender.send(record_id, right), - right_sender.send(record_id, left + additive_error), - ) - .await?; - - if excluded { - Ok(None) - } else { - let (share_from_left, _share_from_right): (F, F) = - try_join(left_recv.receive(record_id), right_recv.receive(record_id)).await?; + pub fn malicious_validation_fail(partial: bool) { + const STEP: &str = "malicious-reveal"; - Ok(Some(left + right + share_from_left)) - } + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H3, config.role_assignment(), move |ctx, data| { + // H3 runs an additive attack against H1 (on the right) by + // adding a 1 to the left part of share it is holding + if ctx.gate.as_ref().contains(STEP) && ctx.dest == Role::H1 { + let v = Fp31::deserialize_from_slice(data) + Fp31::ONE; + v.serialize_to_slice(data); + } + }); + + let world = TestWorld::new_with(config); + let input: Fp31 = rng.gen(); + world + .malicious(input, |ctx, share| async move { + let v = ctx.validator(); + let m_ctx = v.context().set_total_records(1); + let malicious_share = v.context().upgrade(share).await.unwrap(); + let m_ctx = m_ctx.narrow(STEP); + let my_role = m_ctx.role(); + + let r = if partial { + malicious_share + .partial_reveal(m_ctx, RecordId::FIRST, Role::H3) + .await + } else { + malicious_share + .generic_reveal(m_ctx, RecordId::FIRST, None) + .await + }; + + // H1 should be able to see the mismatch + if my_role == Role::H1 { + assert!(matches!(r, Err(Error::MaliciousRevealFailed))); + } else { + // sanity check + r.unwrap(); + } + }) + .await; + }); } } diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 8accd0e90..c1845dcc1 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -170,7 +170,6 @@ impl MaliciousAccumulator { // Then, the parties call `Ḟ_product` on vectors // `([[ᾶ_1]], . . . , [[ᾶ_N ]], [[β_1]], . . . , [[β_M]])` and `([[z_1]], . . . , [[z_N]], [[v_1]], . . . , [[v_M]])` to receive `[[ŵ]]` let induced_share = Replicated::new(x.left().to_extended(), x.right().to_extended()); - let random_constant = prss.generate(record_id); let u_contribution: F::ExtendedField = Self::compute_dot_product_contribution(&random_constant, input.rx()); diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index ad0778649..60fb642be 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -13,6 +13,7 @@ use tracing::{Instrument, Level, Span}; use crate::{ helpers::{ + in_memory_config::{passthrough, DynStreamInterceptor}, Gateway, GatewayConfig, HelperIdentity, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, Role, RoleAssignment, Transport, }, @@ -109,6 +110,34 @@ pub struct TestWorldConfig { /// performance and want to use compact gates, you set this to the gate narrowed to root step /// of the protocol being tested. pub initial_gate: Option, + + /// An optional interceptor to be put inside the in-memory stream + /// module. This allows inspecting and modifying stream content + /// for each communication round between any pair of helpers. + /// The application include: + /// * Malicious behavior. This can help simulating a malicious + /// actor being present in the system by running one or several + /// additive attacks. + /// * Data corruption. Tests can simulate bit flips that occur + /// at the network layer and check whether IPA can recover from + /// these (checksums, etc). + /// + /// The interface is pretty low level because of the layer + /// where it operates. [`StreamInterceptor`] interface provides + /// access to the circuit gate and raw bytes being + /// sent between helpers and/or shards. [`MaliciousHelper`] + /// is one example of helper that could be built on top + /// of this generic interface. It is recommended to build + /// a custom interceptor for repeated use-cases that is less + /// generic than [`StreamInterceptor`]. + /// + /// If interception is not required, the [`passthrough`] interceptor + /// may be used. + /// + /// [`StreamInterceptor`]: crate::helpers::in_memory_config::StreamInterceptor + /// [`MaliciousHelper`]: crate::helpers::in_memory_config::MaliciousHelper + /// [`passthrough`]: crate::helpers::in_memory_config::passthrough + pub stream_interceptor: DynStreamInterceptor, } impl ShardingScheme for NotSharded { @@ -236,7 +265,8 @@ impl TestWorld { println!("TestWorld random seed {seed}", seed = config.seed); let shard_count = ShardIndex::try_from(S::SHARDS).unwrap(); - let shard_network = InMemoryShardNetwork::with_shards(shard_count); + let shard_network = + InMemoryShardNetwork::with_stream_interceptor(shard_count, &config.stream_interceptor); let shards = shard_count .iter() @@ -285,6 +315,7 @@ impl Default for TestWorldConfig { role_assignment: None, seed: thread_rng().next_u64(), initial_gate: None, + stream_interceptor: passthrough(), } } } @@ -602,9 +633,11 @@ impl ShardWorld { shard_seed: u64, transports: [InMemoryTransport; 3], ) -> Self { - // todo: B -> seed let participants = make_participants(&mut StdRng::seed_from_u64(config.seed + shard_seed)); - let network = InMemoryMpcNetwork::default(); + let network = InMemoryMpcNetwork::with_stream_interceptor( + InMemoryMpcNetwork::noop_handlers(), + &config.stream_interceptor, + ); let mut gateways = zip3_ref(&network.transports(), &transports).map(|(mpc, shard)| { Gateway::new( @@ -720,9 +753,19 @@ mod tests { sync::{Arc, Mutex}, }; + use futures_util::future::try_join4; + use crate::{ - ff::{boolean_array::BA3, U128Conversions}, - protocol::{context::Context, prss::SharedRandomness}, + ff::{boolean_array::BA3, Field, Fp31, U128Conversions}, + helpers::{ + in_memory_config::{MaliciousHelper, MaliciousHelperContext}, + Direction, Role, + }, + protocol::{context::Context, prss::SharedRandomness, RecordId}, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, + }, sharding::ShardConfiguration, test_executor::run, test_fixture::{world::WithShards, Reconstruct, Runner, TestWorld, TestWorldConfig}, @@ -785,4 +828,67 @@ mod tests { .await.into_iter().map(|v| v.reconstruct()).collect::>(); }); } + + #[test] + fn peeker_can_corrupt_data() { + const STEP: &str = "corruption"; + run(|| async move { + fn corrupt_byte(data: &mut u8) { + // flipping the bit may result in prime overflow, + // so we just set the value to be 0 or 1 if it was 0 + if *data == 0 { + *data = 1; + } else { + *data = 0; + } + } + + let mut config = TestWorldConfig::default(); + config.stream_interceptor = MaliciousHelper::new( + Role::H1, + config.role_assignment(), + |ctx: &MaliciousHelperContext, data: &mut Vec| { + if ctx.gate.as_ref().contains(STEP) { + corrupt_byte(&mut data[0]); + } + }, + ); + + let world = TestWorld::new_with(config); + + let shares = world + .semi_honest((), |ctx, ()| async move { + let ctx = ctx.narrow(STEP).set_total_records(1); + let (l, r): (Fp31, Fp31) = ctx.prss().generate(RecordId::FIRST); + + let ((), (), r, l) = try_join4( + ctx.send_channel(ctx.role().peer(Direction::Right)) + .send(RecordId::FIRST, r), + ctx.send_channel(ctx.role().peer(Direction::Left)) + .send(RecordId::FIRST, l), + ctx.recv_channel::(ctx.role().peer(Direction::Right)) + .receive(RecordId::FIRST), + ctx.recv_channel::(ctx.role().peer(Direction::Left)) + .receive(RecordId::FIRST), + ) + .await + .unwrap(); + + AdditiveShare::new(l, r) + }) + .await; + + println!("{shares:?}"); + // shares received from H1 must be corrupted + assert_ne!(shares[0].right(), shares[1].left()); + assert_ne!(shares[0].left(), shares[2].right()); + + // and must be set to either 0 or 1 + assert!([Fp31::ZERO, Fp31::ONE].contains(&shares[1].left())); + assert!([Fp31::ZERO, Fp31::ONE].contains(&shares[2].right())); + + // values shared between H2 and H3 must be consistent + assert_eq!(shares[1].right(), shares[2].left()); + }); + } }