Skip to content
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/dp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where
let aggregation_input = Box::pin(stream::iter(vector_input_to_agg.into_iter()).map(Ok));
// Step 3: Call `aggregate_values` to sum up Bernoulli noise.
let noise_vector: Result<BitDecomposed<AdditiveShare<Boolean, { B }>>, Error> =
aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli as usize).await;
aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli as usize, false).await;
noise_vector
}
/// `apply_dp_noise` takes the noise distribution parameters (`num_bernoulli` and in the future `quantization_scale`)
Expand Down
72 changes: 62 additions & 10 deletions ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ where
.try_flatten_iters(),
);
let aggregated_result =
aggregate_values::<_, HV, B>(ctx, aggregation_input, contributions_stream_len).await?;
aggregate_values::<_, HV, B>(ctx, aggregation_input, contributions_stream_len, false)
.await?;
Ok(Vec::transposed_from(&aggregated_result)?)
}

Expand Down Expand Up @@ -218,6 +219,7 @@ pub async fn aggregate_values<'ctx, 'fut, C, OV, const B: usize>(
ctx: C,
mut aggregated_stream: Pin<Box<dyn Stream<Item = AggResult<B>> + Send + 'fut>>,
mut num_rows: usize,
truncate: bool,
) -> Result<BitDecomposed<Replicated<Boolean, B>>, Error>
where
'ctx: 'fut,
Expand Down Expand Up @@ -271,6 +273,15 @@ where
.await?;
sum.push(carry);
Ok(sum)
} else if truncate {
let (sum, _) = integer_add::<_, SixteenBitStep, B>(
ctx.narrow(&AggregateValuesStep::Add),
record_id,
&a,
&b,
)
.await?;
Ok(sum)
} else {
assert!(
OV::BITS <= SixteenBitStep::BITS,
Expand Down Expand Up @@ -361,7 +372,12 @@ pub mod tests {
let result: BitDecomposed<BA8> = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand All @@ -384,7 +400,12 @@ pub mod tests {
let result = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand All @@ -410,7 +431,12 @@ pub mod tests {
let result = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand Down Expand Up @@ -438,7 +464,12 @@ pub mod tests {
let result = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand All @@ -465,7 +496,12 @@ pub mod tests {
let result = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand All @@ -484,7 +520,7 @@ pub mod tests {
run(|| async move {
let result = TestWorld::default()
.upgraded_semi_honest((), |ctx, ()| {
aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0)
aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0, false)
})
.await
.map(Result::unwrap)
Expand All @@ -505,7 +541,12 @@ pub mod tests {
let result = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len();
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await;

Expand All @@ -525,7 +566,12 @@ pub mod tests {
let _ = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len() + 1;
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand All @@ -547,7 +593,12 @@ pub mod tests {
let _ = TestWorld::default()
.upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| {
let num_rows = inputs.len() - 1;
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows)
aggregate_values::<_, BA8, 8>(
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
.map(Result::unwrap)
Expand Down Expand Up @@ -633,6 +684,7 @@ pub mod tests {
ctx,
stream::iter(inputs).boxed(),
num_rows,
false,
)
})
.await
Expand Down
176 changes: 168 additions & 8 deletions ipa-core/src/protocol/ipa_prf/boolean_ops/sigmoid.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
use std::{iter::repeat, ops::Not};
use std::{
iter::{repeat, zip},
ops::Not,
};

use futures::future::{try_join, try_join4, try_join5};
use futures::{
future::{try_join, try_join4, try_join5},
stream, StreamExt,
};

use super::multiplication::integer_mul;
use crate::{
error::Error,
ff::boolean::Boolean,
error::{Error, LengthError},
ff::{boolean::Boolean, boolean_array::BA8},
helpers::{repeat_n, TotalRecords},
protocol::{
basics::mul::SecureMul, boolean::step::ThirtyTwoBitStep, context::Context,
basics::mul::SecureMul,
boolean::{step::ThirtyTwoBitStep, NBitStep},
context::Context,
ipa_prf::aggregation::aggregate_values,
BooleanProtocols, RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd},
secret_sharing::{
replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd, TransposeFrom,
},
};

async fn a_times_b_and_not_b<C, const N: usize>(
Expand Down Expand Up @@ -158,13 +171,80 @@ where
]))
}

// edge_weights[0] holds all the edge weights coming _out_ from the first neuron in the previous layer
pub async fn one_layer<C, S, I, const N: usize>(
ctx: C,
last_layer_neurons: Vec<AdditiveShare<BA8>>,
edge_weights: I,
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
C: Context,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
I: IntoIterator<Item = BitDecomposed<AdditiveShare<Boolean, N>>>,
BitDecomposed<AdditiveShare<Boolean, N>>:
for<'a> TransposeFrom<&'a Vec<AdditiveShare<BA8>>, Error = LengthError>,
{
let multiplication_ctx = ctx.narrow("activation_times_edge_weight");

let contributions_per_neuron_in_last_layer: Vec<BitDecomposed<AdditiveShare<Boolean, N>>> = ctx
.parallel_join(zip(edge_weights, last_layer_neurons).enumerate().map(
|(i, (outbound_edge_weights, last_layer_neuron))| {
let repeated_neuron_activation = BitDecomposed::transposed_from(
&repeat_n(last_layer_neuron, N).collect::<Vec<_>>(),
)
.unwrap();
let c = multiplication_ctx.clone();
async move {
let lossless_result = integer_mul::<_, S, N>(
c,
RecordId::from(i),
&repeated_neuron_activation,
&outbound_edge_weights,
)
.await?;
// Neuron activtion is an 8-bit value meant to represent a
// fractional number in the range [0, 1)
// So after multiplying this value with the edge weight,
// we must shift 8 bits down to effectively divide by 256
let (_, top_8_bits) = lossless_result.split_at(8);
Ok::<_, Error>(top_8_bits)
}
},
))
.await?;

let total_input = aggregate_values::<_, BA8, N>(
ctx.narrow("aggregated_edge_weights"),
Box::pin(stream::iter(contributions_per_neuron_in_last_layer.into_iter()).map(Ok)),
N,
true,
)
.await?;

sigmoid::<_, N>(
ctx.narrow("sigmoid")
.set_total_records(TotalRecords::Indeterminate),
RecordId::FIRST,
&total_input,
)
.await
}

#[cfg(all(test, unit_test))]
mod test {
use std::num::TryFromIntError;
use std::{iter::zip, num::TryFromIntError};

use rand::{thread_rng, Rng};

use super::one_layer;
use crate::{
ff::{boolean_array::BA8, U128Conversions},
protocol::{context::Context, ipa_prf::boolean_ops::sigmoid::sigmoid, RecordId},
protocol::{
boolean::step::DefaultBitStep, context::Context,
ipa_prf::boolean_ops::sigmoid::sigmoid, RecordId,
},
secret_sharing::{BitDecomposed, SharedValue, TransposeFrom},
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld},
Expand Down Expand Up @@ -237,4 +317,84 @@ mod test {
}
});
}

#[test]
#[allow(clippy::cast_precision_loss)]
fn semi_honest_neural_network() {
run(|| async move {
let world = TestWorld::default();

let mut rng = thread_rng();

let edge_weights_matrix = (0..32)
.map(|i| {
(0..32)
.map(|j| {
// offset is in the range [-16, 16)
let offset = (3 * i + 5 * j) % 32 - 16;
let modulo = (256 + offset) % 256;
BA8::truncate_from(modulo as u128)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let prev_neurons = (0..32).map(|_| rng.gen::<BA8>()).collect::<Vec<_>>();

let result: Vec<BA8> = world
.upgraded_semi_honest(
(
edge_weights_matrix
.clone()
.into_iter()
.map(|x| x.into_iter()),
prev_neurons.clone().into_iter(),
),
|ctx, (edge_weights, prev_neurons)| async move {
let matrix_of_edge_weights = edge_weights
.iter()
.map(|chunk| BitDecomposed::transposed_from(chunk).unwrap());
let result = one_layer::<_, DefaultBitStep, _, 32>(
ctx.set_total_records(32),
prev_neurons,
matrix_of_edge_weights,
)
.await
.unwrap();

Vec::transposed_from(&result).unwrap()
},
)
.await
.reconstruct();

let expected_activations = zip(edge_weights_matrix, prev_neurons)
.fold([0; 32], |mut acc, (edge_weights, n)| {
let contributions_from_neuron = edge_weights.into_iter().map(|e| {
let lossless = as_i128(e) * i128::try_from(n.as_u128()).unwrap();
lossless >> 8
});

acc.iter_mut()
.zip(contributions_from_neuron)
.for_each(|(a, c)| *a += c);
acc
})
.map(|total_input| {
(
total_input,
piecewise_linear_sigmoid_approximation(total_input).unwrap(),
)
});

for ((total_input, expected_activation), actual_result) in
expected_activations.iter().zip(result)
{
println!(
"total_input: {:?}, expected_activation: {:?}, actual_result: {:?}",
total_input, expected_activation, actual_result
);
assert_eq!(actual_result.as_u128(), *expected_activation);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ where
seq_join(sh_ctx.active_work(), stream::iter(chunked_user_results)).try_flatten_iters(),
);
let aggregated_result: BitDecomposed<AdditiveShare<Boolean, B>> =
aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs).await?;
aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs, false).await?;

let transposed_aggregated_result: Vec<Replicated<HV>> =
Vec::transposed_from(&aggregated_result)?;
Expand Down