Skip to content

Commit

Permalink
Add reach only shuffle and aggregation phase. (#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
ple13 authored May 30, 2024
1 parent 80ecbc7 commit bb94176
Show file tree
Hide file tree
Showing 15 changed files with 1,166 additions and 122 deletions.
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ bazel_dep(
)
bazel_dep(
name = "any-sketch",
version = "0.8.1",
version = "0.9.0",
repo_name = "any_sketch",
)
bazel_dep(
Expand Down
20 changes: 10 additions & 10 deletions MODULE.bazel.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,25 @@ absl::Status VerifySketchParameters(const ShareShuffleSketchParams& params) {
return absl::OkStatus();
}

// Checks if val is a prime.
//
// Values less than or equal to 1 (including all negative values) are not
// prime. Primality is decided by trial division, so the computation
// complexity is O(sqrt(val)).
bool IsPrime(int val) {
  if (val <= 1) {
    return false;
  }
  // The loop bound is written as `i <= val / i` instead of `i * i <= val`.
  // Both express i <= sqrt(val), but the product i * i overflows a signed int
  // (undefined behavior) once val gets close to INT_MAX, whereas the division
  // can never overflow.
  for (int i = 2; i <= val / i; i++) {
    if (val % i == 0) {
      return false;
    }
  }
  return true;
}

} // namespace

absl::StatusOr<CompleteShufflePhaseResponse> CompleteShufflePhase(
absl::StatusOr<CompleteShufflePhaseResponse>
CompleteReachAndFrequencyShufflePhase(
const CompleteShufflePhaseRequest& request) {
StartedThreadCpuTimer timer;
CompleteShufflePhaseResponse response;
Expand All @@ -100,12 +116,9 @@ absl::StatusOr<CompleteShufflePhaseResponse> CompleteShufflePhase(
"actual is $2.",
i, sketch_size, share_vector.size()));
}
for (int j = 0; j < sketch_size; j++) {
// It's guaranteed that (combined_sketch[j] + share_vector[j]) is
// not greater than 2^{32}-1.
combined_sketch[j] = (combined_sketch[j] + share_vector[j]) %
request.sketch_params().ring_modulus();
}
ASSIGN_OR_RETURN(combined_sketch,
VectorAddMod(combined_sketch, share_vector,
request.sketch_params().ring_modulus()));
}

ASSIGN_OR_RETURN(PrngSeed seed,
Expand All @@ -126,7 +139,8 @@ absl::StatusOr<CompleteShufflePhaseResponse> CompleteShufflePhase(

// Generates local noise registers.
ASSIGN_OR_RETURN(std::vector<uint32_t> noise_registers,
GenerateNoiseRegisters(request.sketch_params(), *noiser));
GenerateReachAndFrequencyNoiseRegisters(
request.sketch_params(), *noiser));

// Both workers generate common random vectors from the common random seed.
// rand_vec_1 || rand_vec_2 <-- PRNG(seed).
Expand Down Expand Up @@ -190,11 +204,142 @@ absl::StatusOr<CompleteShufflePhaseResponse> CompleteShufflePhase(
return response;
}

absl::StatusOr<CompleteAggregationPhaseResponse> CompleteAggregationPhase(
// Runs the shuffle phase of the reach-only protocol on a single request.
//
// The steps below are performed in this exact order; every draw from the
// PRNG seeded with the common random seed must keep its position:
//   1. Validate the sketch parameters, require a prime ring modulus, and
//      reject an empty list of sketch shares.
//   2. Add all input shares element-wise modulo the ring modulus.
//   3. Multiply each combined register by a non-zero random value.
//   4. When DP params are set, append two noise share vectors.
//   5. Shuffle the resulting vector with a seed drawn from the same PRNG.
//
// Returns the shuffled combined sketch together with the elapsed CPU
// duration, or the first error status encountered.
absl::StatusOr<CompleteShufflePhaseResponse> CompleteReachOnlyShufflePhase(
    const CompleteShufflePhaseRequest& request) {
  StartedThreadCpuTimer timer;

  const auto& params = request.sketch_params();
  RETURN_IF_ERROR(VerifySketchParameters(params));

  // The reach-only protocol requires the ring modulus to be a prime.
  const auto ring_modulus = params.ring_modulus();
  if (!IsPrime(ring_modulus)) {
    return absl::InvalidArgumentError("The ring modulus must be a prime.");
  }

  const auto& input_shares = request.sketch_shares();
  if (input_shares.empty()) {
    return absl::InvalidArgumentError("Sketch shares must not be empty.");
  }

  const int sketch_size = params.register_count();

  // Accumulates every input share into a single vector (mod ring modulus).
  std::vector<uint32_t> combined_sketch(sketch_size, 0);
  for (int share_index = 0; share_index < input_shares.size();
       share_index++) {
    ASSIGN_OR_RETURN(
        std::vector<uint32_t> share_vector,
        GetShareVectorFromSketchShare(params, input_shares.Get(share_index)));
    if (share_vector.size() != sketch_size) {
      return absl::InvalidArgumentError(absl::Substitute(
          "The $0-th sketch share has invalid size. Expect $1 but the "
          "actual is $2.",
          share_index, sketch_size, share_vector.size()));
    }
    ASSIGN_OR_RETURN(combined_sketch, VectorAddMod(combined_sketch,
                                                   share_vector, ring_modulus));
  }

  // The PRNG is seeded with the common random seed. It supplies, in order:
  // the non-zero multipliers, the random shares for the noise registers (if
  // needed), and the seed used for shuffling.
  ASSIGN_OR_RETURN(PrngSeed seed,
                   GetPrngSeedFromString(request.common_random_seed()));
  ASSIGN_OR_RETURN(std::unique_ptr<UniformPseudorandomGenerator> prng,
                   CreatePrngFromSeed(seed));

  // Draws one random multiplier in [1, modulus) per register.
  ASSIGN_OR_RETURN(
      std::vector<uint32_t> multipliers,
      prng->GenerateNonZeroUniformRandomRange(sketch_size, ring_modulus));

  // Turns a share of a non-zero register into a share of a non-zero random
  // value. The product is computed in 64 bits before the reduction.
  for (int reg_index = 0; reg_index < sketch_size; reg_index++) {
    const uint64_t lhs{combined_sketch[reg_index]};
    const uint64_t rhs{multipliers[reg_index]};
    combined_sketch[reg_index] = lhs * rhs % ring_modulus;
  }

  // Noise is only added when differential privacy params are provided.
  if (request.has_dp_params()) {
    // The noiser generates blind histogram noise to hide the actual
    // frequency histogram counts.
    auto noiser = GetBlindHistogramNoiser(request.dp_params(),
                                          /*contributors_count=*/2,
                                          request.noise_mechanism());

    // Local noise registers of this worker.
    ASSIGN_OR_RETURN(std::vector<uint32_t> noise_registers,
                     GenerateReachOnlyNoiseRegisters(params, *noiser));

    // Both workers derive the same random vectors from the common seed:
    // rand_vec_1 || rand_vec_2 <-- PRNG(seed).
    const auto noise_count = noise_registers.size();
    ASSIGN_OR_RETURN(
        std::vector<uint32_t> rand_vec_1,
        prng->GenerateUniformRandomRange(noise_count, ring_modulus));
    ASSIGN_OR_RETURN(
        std::vector<uint32_t> rand_vec_2,
        prng->GenerateUniformRandomRange(noise_count, ring_modulus));

    // Builds the local noise shares from the common random vectors.
    // Worker 1 holds {(noise_registers_1 - rand_vec_1) || rand_vec_2}.
    // Worker 2 holds {rand_vec_1 || (noise_registers_2 - rand_vec_2)}.
    std::vector<uint32_t> first_local_noise_share;
    std::vector<uint32_t> second_local_noise_share;
    switch (request.order()) {
      case CompleteShufflePhaseRequest::FIRST: {
        ASSIGN_OR_RETURN(
            first_local_noise_share,
            VectorSubMod(noise_registers, rand_vec_1, ring_modulus));
        second_local_noise_share = std::move(rand_vec_2);
        break;
      }
      case CompleteShufflePhaseRequest::SECOND: {
        first_local_noise_share = std::move(rand_vec_1);
        ASSIGN_OR_RETURN(
            second_local_noise_share,
            VectorSubMod(noise_registers, rand_vec_2, ring_modulus));
        break;
      }
      default:
        return absl::InvalidArgumentError(
            "Non aggregator order must be specified.");
    }

    // The first noise share goes right after the sketch registers, the
    // second noise share right after the first one.
    combined_sketch.insert(combined_sketch.end(),
                           first_local_noise_share.begin(),
                           first_local_noise_share.end());
    combined_sketch.insert(combined_sketch.end(),
                           second_local_noise_share.begin(),
                           second_local_noise_share.end());
  }

  // Derives the shuffle seed from the PRNG and shuffles the shares.
  ASSIGN_OR_RETURN(
      std::vector<unsigned char> shuffle_seed_vec,
      prng->GeneratePseudorandomBytes(kBytesPerAes256Key + kBytesPerAes256Iv));
  ASSIGN_OR_RETURN(PrngSeed shuffle_seed,
                   GetPrngSeedFromCharVector(shuffle_seed_vec));
  RETURN_IF_ERROR(SecureShuffleWithSeed(combined_sketch, shuffle_seed));

  // Packs the shuffled vector and the consumed CPU time into the response.
  CompleteShufflePhaseResponse response;
  response.mutable_combined_sketch()->Add(combined_sketch.begin(),
                                          combined_sketch.end());
  *response.mutable_elapsed_cpu_duration() =
      google::protobuf::util::TimeUtil::MillisecondsToDuration(
          timer.ElapsedMillis());
  return response;
}

absl::StatusOr<CompleteAggregationPhaseResponse>
CompleteReachAndFrequencyAggregationPhase(
const CompleteAggregationPhaseRequest& request) {
StartedThreadCpuTimer timer;
CompleteAggregationPhaseResponse response;

RETURN_IF_ERROR(VerifySketchParameters(request.sketch_params()));
if (request.sketch_shares().size() != kWorkerCount) {
return absl::InvalidArgumentError(
"The number of share vectors must be equal to the number of "
Expand All @@ -214,20 +359,18 @@ absl::StatusOr<CompleteAggregationPhaseResponse> CompleteAggregationPhase(
// frequency_histogram[i] = the number of times value i occurs where
// i in {0, ..., maximum_frequency-1}.
absl::flat_hash_map<int, int64_t> frequency_histogram;
for (auto x : combined_sketch) {
if (x > request.sketch_params().maximum_combined_frequency() &&
x != (request.sketch_params().ring_modulus() - 1)) {
for (const auto reg : combined_sketch) {
if (reg > request.sketch_params().maximum_combined_frequency() &&
reg != (request.sketch_params().ring_modulus() - 1)) {
return absl::InternalError(absl::Substitute(
"The combined register value, which is $0, is not valid. It must be "
"either the "
"sentinel value, which is $1, or less that or equal to the combined "
"maximum "
"frequency, which is $2.",
x, request.sketch_params().ring_modulus() - 1,
"either the sentinel value, which is $1, or less that or equal to "
"the combined maximum frequency, which is $2.",
reg, request.sketch_params().ring_modulus() - 1,
request.sketch_params().maximum_combined_frequency()));
}
if (x < maximum_frequency) {
frequency_histogram[x]++;
if (reg < maximum_frequency) {
frequency_histogram[reg]++;
}
}

Expand Down Expand Up @@ -302,7 +445,7 @@ absl::StatusOr<CompleteAggregationPhaseResponse> CompleteAggregationPhase(
return absl::InvalidArgumentError(
"There is neither actual data nor effective noise in the request.");
}
// Estimates reach using the number of non-empty buckets and the vid sampling
// Estimates reach using the number of non-empty buckets and the VID sampling
// interval width.
ASSIGN_OR_RETURN(int64_t reach,
EstimateReach(non_empty_register_count,
Expand All @@ -324,4 +467,57 @@ absl::StatusOr<CompleteAggregationPhaseResponse> CompleteAggregationPhase(
return response;
}

// Runs the aggregation phase of the reach-only protocol.
//
// Combines the sketch shares in the request, counts the registers whose
// combined value is non-zero, removes the noise baseline from that count when
// DP params are set, and estimates the reach from the adjusted count and the
// VID sampling interval width.
//
// Returns InvalidArgumentError when the number of share vectors differs from
// kWorkerCount, and propagates any error from parameter verification, share
// combination, or reach estimation. On success the response carries the reach
// and the elapsed CPU duration.
absl::StatusOr<CompleteAggregationPhaseResponse>
CompleteReachOnlyAggregationPhase(
    const CompleteAggregationPhaseRequest& request) {
  StartedThreadCpuTimer timer;
  CompleteAggregationPhaseResponse response;
  RETURN_IF_ERROR(VerifySketchParameters(request.sketch_params()));
  if (request.sketch_shares().size() != kWorkerCount) {
    return absl::InvalidArgumentError(
        "The number of share vectors must be equal to the number of "
        "non-aggregators.");
  }
  ASSIGN_OR_RETURN(
      std::vector<uint32_t> combined_sketch,
      CombineSketchShares(request.sketch_params(), request.sketch_shares()));

  // Counts the non-zero registers.
  int64_t non_empty_register_count = 0;
  for (const auto reg : combined_sketch) {
    if (reg != 0) {
      non_empty_register_count++;
    }
  }

  // Adjusts the non-empty register count according to the noise baseline.
  if (request.has_dp_params()) {
    auto noiser = GetBlindHistogramNoiser(request.dp_params(), kWorkerCount,
                                          request.noise_mechanism());
    const int64_t noise_baseline =
        noiser->options().shift_offset * kWorkerCount;

    // Removes the noise baseline from the non-empty register count.
    non_empty_register_count -= noise_baseline;

    // Clamps the count to [0, register_count]. The template argument is
    // spelled out because int64_t is `long` on some platforms and `long long`
    // on others; a bare `0L` literal only deduces on the former.
    non_empty_register_count =
        std::max<int64_t>(0, non_empty_register_count);
    non_empty_register_count = std::min<int64_t>(
        request.sketch_params().register_count(), non_empty_register_count);
  }

  // Estimates reach using the number of non-empty buckets and the VID sampling
  // interval width.
  ASSIGN_OR_RETURN(int64_t reach,
                   EstimateReach(non_empty_register_count,
                                 request.vid_sampling_interval_width()));

  response.set_reach(reach);
  *response.mutable_elapsed_cpu_duration() =
      google::protobuf::util::TimeUtil::MillisecondsToDuration(
          timer.ElapsedMillis());
  return response;
}

} // namespace wfa::measurement::internal::duchy::protocol::share_shuffle
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@ using ::wfa::measurement::internal::duchy::protocol::
using ::wfa::measurement::internal::duchy::protocol::
CompleteShufflePhaseResponse;

absl::StatusOr<CompleteShufflePhaseResponse> CompleteShufflePhase(
absl::StatusOr<CompleteShufflePhaseResponse>
CompleteReachAndFrequencyShufflePhase(
const CompleteShufflePhaseRequest& request);

absl::StatusOr<CompleteAggregationPhaseResponse> CompleteAggregationPhase(
absl::StatusOr<CompleteShufflePhaseResponse> CompleteReachOnlyShufflePhase(
const CompleteShufflePhaseRequest& request);

absl::StatusOr<CompleteAggregationPhaseResponse>
CompleteReachAndFrequencyAggregationPhase(
const CompleteAggregationPhaseRequest& request);

absl::StatusOr<CompleteAggregationPhaseResponse>
CompleteReachOnlyAggregationPhase(
const CompleteAggregationPhaseRequest& request);

} // namespace wfa::measurement::internal::duchy::protocol::share_shuffle
Expand Down
Loading

0 comments on commit bb94176

Please sign in to comment.