diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index cde10f1b7..3510495b8 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -58,7 +58,7 @@ where let aggregation_input = Box::pin(stream::iter(vector_input_to_agg.into_iter()).map(Ok)); // Step 3: Call `aggregate_values` to sum up Bernoulli noise. let noise_vector: Result>, Error> = - aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli as usize).await; + aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli as usize, false).await; noise_vector } /// `apply_dp_noise` takes the noise distribution parameters (`num_bernoulli` and in the future `quantization_scale`) diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 14e722fd9..9d76b5b8a 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -188,7 +188,8 @@ where .try_flatten_iters(), ); let aggregated_result = - aggregate_values::<_, HV, B>(ctx, aggregation_input, contributions_stream_len).await?; + aggregate_values::<_, HV, B>(ctx, aggregation_input, contributions_stream_len, false) + .await?; Ok(Vec::transposed_from(&aggregated_result)?) } @@ -218,6 +219,7 @@ pub async fn aggregate_values<'ctx, 'fut, C, OV, const B: usize>( ctx: C, mut aggregated_stream: Pin> + Send + 'fut>>, mut num_rows: usize, + truncate: bool, ) -> Result>, Error> where 'ctx: 'fut, @@ -271,6 +273,15 @@ where .await?; sum.push(carry); Ok(sum) + } else if truncate { + let (sum, _) = integer_add::<_, SixteenBitStep, B>( + ctx.narrow(&AggregateValuesStep::Add), + record_id, + &a, + &b, + ) + .await?; + Ok(sum) } else { assert!( OV::BITS <= SixteenBitStep::BITS, @@ -361,7 +372,12 @@ pub mod tests { let result: BitDecomposed = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -384,7 +400,12 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -410,7 +431,12 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -438,7 +464,12 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -465,7 +496,12 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -484,7 +520,7 @@ pub mod tests { run(|| async move { let result = TestWorld::default() .upgraded_semi_honest((), |ctx, ()| { - aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0) + aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0, false) }) .await .map(Result::unwrap) @@ -505,7 +541,12 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await; @@ -525,7 +566,12 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() + 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -547,7 +593,12 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() - 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>( + ctx, + stream::iter(inputs).boxed(), + num_rows, + false, + ) }) .await .map(Result::unwrap) @@ -633,6 +684,7 @@ pub mod tests { ctx, stream::iter(inputs).boxed(), num_rows, + false, ) }) .await diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/sigmoid.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/sigmoid.rs index 5fc7eb5c6..9fdebad6a 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/sigmoid.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/sigmoid.rs @@ -1,15 +1,28 @@ -use std::{iter::repeat, ops::Not}; +use std::{ + iter::{repeat, zip}, + ops::Not, +}; -use futures::future::{try_join, try_join4, try_join5}; +use futures::{ + future::{try_join, try_join4, try_join5}, + stream, StreamExt, +}; +use super::multiplication::integer_mul; use crate::{ - error::Error, - ff::boolean::Boolean, + error::{Error, LengthError}, + ff::{boolean::Boolean, boolean_array::BA8}, + helpers::{repeat_n, TotalRecords}, protocol::{ - basics::mul::SecureMul, boolean::step::ThirtyTwoBitStep, context::Context, + basics::mul::SecureMul, + boolean::{step::ThirtyTwoBitStep, NBitStep}, + context::Context, + ipa_prf::aggregation::aggregate_values, BooleanProtocols, RecordId, }, - secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd, TransposeFrom, + }, }; async fn a_times_b_and_not_b( @@ -158,13 +171,80 @@ where ])) } +// edge_weights[0] holds all the edge weights coming _out_ from the first neuron in the previous layer +pub async fn one_layer( + ctx: C, + last_layer_neurons: Vec>, + edge_weights: I, +) -> Result>, Error> +where + C: Context, + S: NBitStep, + Boolean: FieldSimd, + AdditiveShare: BooleanProtocols, + I: IntoIterator>>, + BitDecomposed>: + for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, +{ + let multiplication_ctx = ctx.narrow("activation_times_edge_weight"); + + let contributions_per_neuron_in_last_layer: Vec>> = ctx + .parallel_join(zip(edge_weights, last_layer_neurons).enumerate().map( + |(i, (outbound_edge_weights, last_layer_neuron))| { + let repeated_neuron_activation = BitDecomposed::transposed_from( + &repeat_n(last_layer_neuron, N).collect::>(), + ) + .unwrap(); + let c = multiplication_ctx.clone(); + async move { + let lossless_result = integer_mul::<_, S, N>( + c, + RecordId::from(i), + &repeated_neuron_activation, + &outbound_edge_weights, + ) + .await?; + // Neuron activtion is an 8-bit value meant to represent a + // fractional number in the range [0, 1) + // So after multiplying this value with the edge weight, + // we must shift 8 bits down to effectively divide by 256 + let (_, top_8_bits) = lossless_result.split_at(8); + Ok::<_, Error>(top_8_bits) + } + }, + )) + .await?; + + let total_input = aggregate_values::<_, BA8, N>( + ctx.narrow("aggregated_edge_weights"), + Box::pin(stream::iter(contributions_per_neuron_in_last_layer.into_iter()).map(Ok)), + N, + true, + ) + .await?; + + sigmoid::<_, N>( + ctx.narrow("sigmoid") + .set_total_records(TotalRecords::Indeterminate), + RecordId::FIRST, + &total_input, + ) + .await +} + #[cfg(all(test, unit_test))] mod test { - use std::num::TryFromIntError; + use std::{iter::zip, num::TryFromIntError}; + + use rand::{thread_rng, Rng}; + use super::one_layer; use crate::{ ff::{boolean_array::BA8, U128Conversions}, - protocol::{context::Context, ipa_prf::boolean_ops::sigmoid::sigmoid, RecordId}, + protocol::{ + boolean::step::DefaultBitStep, context::Context, + ipa_prf::boolean_ops::sigmoid::sigmoid, RecordId, + }, secret_sharing::{BitDecomposed, SharedValue, TransposeFrom}, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld}, @@ -237,4 +317,84 @@ mod test { } }); } + + #[test] + #[allow(clippy::cast_precision_loss)] + fn semi_honest_neural_network() { + run(|| async move { + let world = TestWorld::default(); + + let mut rng = thread_rng(); + + let edge_weights_matrix = (0..32) + .map(|i| { + (0..32) + .map(|j| { + // offset is in the range [-16, 16) + let offset = (3 * i + 5 * j) % 32 - 16; + let modulo = (256 + offset) % 256; + BA8::truncate_from(modulo as u128) + }) + .collect::>() + }) + .collect::>(); + let prev_neurons = (0..32).map(|_| rng.gen::()).collect::>(); + + let result: Vec = world + .upgraded_semi_honest( + ( + edge_weights_matrix + .clone() + .into_iter() + .map(|x| x.into_iter()), + prev_neurons.clone().into_iter(), + ), + |ctx, (edge_weights, prev_neurons)| async move { + let matrix_of_edge_weights = edge_weights + .iter() + .map(|chunk| BitDecomposed::transposed_from(chunk).unwrap()); + let result = one_layer::<_, DefaultBitStep, _, 32>( + ctx.set_total_records(32), + prev_neurons, + matrix_of_edge_weights, + ) + .await + .unwrap(); + + Vec::transposed_from(&result).unwrap() + }, + ) + .await + .reconstruct(); + + let expected_activations = zip(edge_weights_matrix, prev_neurons) + .fold([0; 32], |mut acc, (edge_weights, n)| { + let contributions_from_neuron = edge_weights.into_iter().map(|e| { + let lossless = as_i128(e) * i128::try_from(n.as_u128()).unwrap(); + lossless >> 8 + }); + + acc.iter_mut() + .zip(contributions_from_neuron) + .for_each(|(a, c)| *a += c); + acc + }) + .map(|total_input| { + ( + total_input, + piecewise_linear_sigmoid_approximation(total_input).unwrap(), + ) + }); + + for ((total_input, expected_activation), actual_result) in + expected_activations.iter().zip(result) + { + println!( + "total_input: {:?}, expected_activation: {:?}, actual_result: {:?}", + total_input, expected_activation, actual_result + ); + assert_eq!(actual_result.as_u128(), *expected_activation); + } + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs index 024042f41..f1eeb3d06 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs @@ -264,7 +264,7 @@ where seq_join(sh_ctx.active_work(), stream::iter(chunked_user_results)).try_flatten_iters(), ); let aggregated_result: BitDecomposed> = - aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs).await?; + aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs, false).await?; let transposed_aggregated_result: Vec> = Vec::transposed_from(&aggregated_result)?;