Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reach only HMSS. #1655

Merged
merged 15 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.FIRST_NON_AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.SECOND_NON_AGGREGATOR
import org.wfanet.measurement.internal.duchy.protocol.CompleteAggregationPhaseRequestKt
import org.wfanet.measurement.internal.duchy.protocol.CompleteAggregationPhaseResponse
import org.wfanet.measurement.internal.duchy.protocol.CompleteShufflePhaseRequest
import org.wfanet.measurement.internal.duchy.protocol.CompleteShufflePhaseRequestKt
import org.wfanet.measurement.internal.duchy.protocol.CompleteShufflePhaseRequestKt.sketchShare
import org.wfanet.measurement.internal.duchy.protocol.CompleteShufflePhaseResponse
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.AggregationPhaseInput
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.ShufflePhaseInput
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage
Expand Down Expand Up @@ -376,6 +378,14 @@ class HonestMajorityShareShuffleMill(
}

private suspend fun shufflePhase(token: ComputationToken): ComputationToken {
val publicApiVersion =
Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)
val measurementSpec =
when (publicApiVersion) {
Version.V2_ALPHA ->
MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec)
}

val requisitions = token.requisitionsList.sortedBy { it.externalKey.externalRequisitionId }

val requisitionBlobs = dataClients.readRequisitionBlobs(token)
Expand All @@ -392,8 +402,12 @@ class HonestMajorityShareShuffleMill(
} else {
CompleteShufflePhaseRequest.NonAggregatorOrder.SECOND
}
reachDpParams = hmss.parameters.reachDpParams
frequencyDpParams = hmss.parameters.frequencyDpParams
if (hmss.parameters.hasReachDpParams()) {
reachDpParams = hmss.parameters.reachDpParams
}
if (hmss.parameters.hasFrequencyDpParams()) {
frequencyDpParams = hmss.parameters.frequencyDpParams
}
noiseMechanism = hmss.parameters.noiseMechanism

val registerCounts = mutableListOf<Long>()
Expand All @@ -416,8 +430,6 @@ class HonestMajorityShareShuffleMill(
secretSeeds.find { it.requisitionId == requisitionId }
?: error("Neither blob and seed received for requisition $requisitionId")

val publicApiVersion =
Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)
val seed =
verifySecretSeed(secretSeed, hmss.encryptionKeyPair.privateKeyId, publicApiVersion)

Expand All @@ -430,9 +442,15 @@ class HonestMajorityShareShuffleMill(
sketchParams = hmss.parameters.sketchParams.copy { registerCount = registerCounts.first() }
}

val result =
val result: CompleteShufflePhaseResponse =
logWallClockDuration(token, CRYPTO_WALL_CLOCK_DURATION, cryptoWallClockDurationHistogram) {
cryptoWorker.completeReachAndFrequencyShufflePhase(request)
when (val measurementType = measurementSpec.measurementTypeCase) {
MeasurementSpec.MeasurementTypeCase.REACH ->
cryptoWorker.completeReachOnlyShufflePhase(request)
MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY ->
cryptoWorker.completeReachAndFrequencyShufflePhase(request)
else -> error("Unsupported measurement type $measurementType")
}
}

logStageDurationMetric(
Expand Down Expand Up @@ -463,23 +481,32 @@ class HonestMajorityShareShuffleMill(
}

private suspend fun aggregationPhase(token: ComputationToken): ComputationToken {
val publicApiVersion =
Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)
val measurementSpec =
when (publicApiVersion) {
Version.V2_ALPHA ->
MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec)
}

val aggregationPhaseInputs = getAggregationPhaseInputs(token)

val request = completeAggregationPhaseRequest {
val hmss = token.computationDetails.honestMajorityShareShuffle
sketchParams = hmss.parameters.sketchParams
maximumFrequency = hmss.parameters.maximumFrequency
val publicApiVersion =
Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)

when (publicApiVersion) {
Version.V2_ALPHA -> {
val measurementSpec =
MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec)
vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width
}
}
reachDpParams = hmss.parameters.reachDpParams
frequencyDpParams = hmss.parameters.frequencyDpParams
if (hmss.parameters.hasReachDpParams()) {
reachDpParams = hmss.parameters.reachDpParams
}
if (hmss.parameters.hasFrequencyDpParams()) {
frequencyDpParams = hmss.parameters.frequencyDpParams
}
noiseMechanism = hmss.parameters.noiseMechanism

for (input in aggregationPhaseInputs) {
Expand All @@ -488,9 +515,15 @@ class HonestMajorityShareShuffleMill(
}
}

val result =
val result: CompleteAggregationPhaseResponse =
logWallClockDuration(token, CRYPTO_WALL_CLOCK_DURATION, cryptoWallClockDurationHistogram) {
cryptoWorker.completeReachAndFrequencyAggregationPhase(request)
when (val measurementType = measurementSpec.measurementTypeCase) {
MeasurementSpec.MeasurementTypeCase.REACH ->
cryptoWorker.completeReachOnlyAggregationPhase(request)
MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY ->
cryptoWorker.completeReachAndFrequencyAggregationPhase(request)
else -> error("Unsupported measurement type $measurementType")
}
}

logStageDurationMetric(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,12 @@ interface HonestMajorityShareShuffleCryptor {
fun completeReachAndFrequencyAggregationPhase(
request: CompleteAggregationPhaseRequest
): CompleteAggregationPhaseResponse

fun completeReachOnlyShufflePhase(
request: CompleteShufflePhaseRequest
): CompleteShufflePhaseResponse

fun completeReachOnlyAggregationPhase(
request: CompleteAggregationPhaseRequest
): CompleteAggregationPhaseResponse
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ class JniHonestMajorityShareShuffleCryptor : HonestMajorityShareShuffleCryptor {
)
}

override fun completeReachOnlyShufflePhase(
request: CompleteShufflePhaseRequest
): CompleteShufflePhaseResponse {
return CompleteShufflePhaseResponse.parseFrom(
HonestMajorityShareShuffleUtility.completeReachOnlyShufflePhase(request.toByteArray())
)
}

override fun completeReachOnlyAggregationPhase(
request: CompleteAggregationPhaseRequest
): CompleteAggregationPhaseResponse {
return CompleteAggregationPhaseResponse.parseFrom(
HonestMajorityShareShuffleUtility.completeReachOnlyAggregationPhase(request.toByteArray())
)
}

companion object {
init {
System.loadLibrary("honest_majority_share_shuffle_utility")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,27 @@ import org.wfanet.measurement.eventdataprovider.noiser.GaussianNoiser
* TODO(@ple13): Refactor the AcdpParamsConverter object to take the noise type into account.
*/
object AcdpParamsConverter {
/** Memoized computation of LLV2 based MPC ACDP params conversion results. */
private val llv2AcdpParamsConversionResults =
ConcurrentHashMap<Llv2AcdpParamsConversionKey, AcdpCharge>()
/** Memoized computation of Mpc based MPC ACDP params conversion results. */
private val MpcAcdpParamsConversionResults =
ConcurrentHashMap<MpcAcdpParamsConversionKey, AcdpCharge>()

/**
* Convert LLV2 based MPC per-query DP charge(epsilon, delta) to ACDP charge(rho, theta). The
* computation result is memoized.
* Convert MPC per-query DP charge(epsilon, delta) to ACDP charge(rho, theta). The computation
* result is memoized.
*
* @param privacyParams Internal DifferentialPrivacyParams.
* @param contributorCount Number of Duchies.
* @return ACDP charge(rho, theta).
*/
fun getLlv2AcdpCharge(privacyParams: DpParams, contributorCount: Int): AcdpCharge {
fun getMpcAcdpCharge(privacyParams: DpParams, contributorCount: Int): AcdpCharge {
require(privacyParams.epsilon > 0 && privacyParams.delta > 0 && contributorCount > 0) {
"Epsilon, delta, and contributor count must be positive, but got: epsilon=${privacyParams.epsilon} delta=${privacyParams.delta} contributorCount=$contributorCount"
}

return llv2AcdpParamsConversionResults.getOrPut(
Llv2AcdpParamsConversionKey(privacyParams, contributorCount)
return MpcAcdpParamsConversionResults.getOrPut(
MpcAcdpParamsConversionKey(privacyParams, contributorCount)
) {
computeLlv2RhoAndTheta(privacyParams, contributorCount)
computeMpcRhoAndTheta(privacyParams, contributorCount)
}
}

Expand All @@ -79,9 +79,9 @@ object AcdpParamsConverter {
* The sum of delta1 and delta2 should be delta. In practice, set delta1 = delta2 = 0.5 * delta
* for simplicity.
*/
private fun getLlv2Deltas(delta: Double): MpcDeltas = MpcDeltas(0.5 * delta, 0.5 * delta)
private fun getMpcDeltas(delta: Double): MpcDeltas = MpcDeltas(0.5 * delta, 0.5 * delta)

fun computeLlv2SigmaDistributedDiscreteGaussian(
fun computeMpcSigmaDistributedDiscreteGaussian(
privacyParams: DpParams,
contributorCount: Int,
): Double {
Expand All @@ -92,7 +92,7 @@ object AcdpParamsConverter {
// an approximation for discrete Gaussian noise here. It generally works for
// epsilon <= 1 but not epsilon > 1

val deltas = getLlv2Deltas(privacyParams.delta)
val deltas = getMpcDeltas(privacyParams.delta)
val delta1 = deltas.delta1

// This simple formula to derive sigmaDistributed is valid only for
Expand All @@ -103,7 +103,7 @@ object AcdpParamsConverter {
return sigma / sqrt(contributorCount.toDouble())
}

private fun computeLlv2MuDiscreteGaussian(
private fun computeMpcMuDiscreteGaussian(
privacyParams: DpParams,
sigmaDistributed: Double,
contributorCount: Int,
Expand All @@ -112,18 +112,18 @@ object AcdpParamsConverter {
// The selection of these two parameters have the following effect: setting delta2 larger
// results in smaller truncation threshold but larger noise standard
// deviation.
val deltas = getLlv2Deltas(privacyParams.delta)
val deltas = getMpcDeltas(privacyParams.delta)
val delta2 = deltas.delta2

return ceil(
sigmaDistributed * sqrt(2 * ln(contributorCount * (1 + exp(privacyParams.epsilon)) / delta2))
)
}

private fun computeLlv2RhoAndTheta(privacyParams: DpParams, contributorCount: Int): AcdpCharge {
private fun computeMpcRhoAndTheta(privacyParams: DpParams, contributorCount: Int): AcdpCharge {
val sigmaDistributed =
computeLlv2SigmaDistributedDiscreteGaussian(privacyParams, contributorCount)
val mu = computeLlv2MuDiscreteGaussian(privacyParams, sigmaDistributed, contributorCount)
computeMpcSigmaDistributedDiscreteGaussian(privacyParams, contributorCount)
val mu = computeMpcMuDiscreteGaussian(privacyParams, sigmaDistributed, contributorCount)

// For reach and frequency, the sensitivity Delta should be 1.
val sensitivity = 1.0
Expand All @@ -147,7 +147,7 @@ object AcdpParamsConverter {
}
}

private data class Llv2AcdpParamsConversionKey(
private data class MpcAcdpParamsConversionKey(
val privacyParams: DpParams,
val contributorCount: Int,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object PrivacyQueryMapper {
private const val SENSITIVITY = 1.0

/**
* Constructs a pbm specific [AcdpQuery] from given proto messages for LiquidLegionsV2 protocol.
* Constructs a pbm specific [AcdpQuery] from given proto messages for Mpc protocols.
*
* @param reference representing the reference key and if the charge is a refund.
* @param measurementSpec The measurementSpec protobuf that is associated with the query. The VID
Expand All @@ -42,7 +42,7 @@ object PrivacyQueryMapper {
* if an error occurs in handling this request. Possible exceptions could include running out of
* privacy budget or a failure to commit the transaction to the database.
*/
fun getLiquidLegionsV2AcdpQuery(
fun getMpcAcdpQuery(
reference: Reference,
measurementSpec: MeasurementSpec,
eventSpecs: Iterable<RequisitionSpec.EventGroupEntry.Value>,
Expand All @@ -51,7 +51,7 @@ object PrivacyQueryMapper {
val acdpCharge =
when (measurementSpec.measurementTypeCase) {
MeasurementTypeCase.REACH -> {
AcdpParamsConverter.getLlv2AcdpCharge(
AcdpParamsConverter.getMpcAcdpCharge(
DpParams(
measurementSpec.reach.privacyParams.epsilon,
measurementSpec.reach.privacyParams.delta,
Expand All @@ -61,15 +61,15 @@ object PrivacyQueryMapper {
}
MeasurementTypeCase.REACH_AND_FREQUENCY -> {
val acdpChargeForReach =
AcdpParamsConverter.getLlv2AcdpCharge(
AcdpParamsConverter.getMpcAcdpCharge(
DpParams(
measurementSpec.reachAndFrequency.reachPrivacyParams.epsilon,
measurementSpec.reachAndFrequency.reachPrivacyParams.delta,
),
contributorCount,
)
val acdpChargeForFrequency =
AcdpParamsConverter.getLlv2AcdpCharge(
AcdpParamsConverter.getMpcAcdpCharge(
DpParams(
measurementSpec.reachAndFrequency.frequencyPrivacyParams.epsilon,
measurementSpec.reachAndFrequency.frequencyPrivacyParams.delta,
Expand Down Expand Up @@ -98,70 +98,6 @@ object PrivacyQueryMapper {
)
}

/**
* Constructs a pbm specific [AcdpQuery] from given proto messages for Hmss protocol.
*
* @param reference representing the reference key and if the charge is a refund.
* @param measurementSpec The measurementSpec protobuf that is associated with the query. The VID
* sampling interval is obtained from this.
* @param eventSpecs event specs from the Requisition. The date range and demo groups are obtained
* from this.
* @param contributorCount number of Duchies
* @throws
* org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBudgetManagerException
* if an error occurs in handling this request. Possible exceptions could include running out of
* privacy budget or a failure to commit the transaction to the database.
*/
fun getHmssAcdpQuery(
reference: Reference,
measurementSpec: MeasurementSpec,
eventSpecs: Iterable<RequisitionSpec.EventGroupEntry.Value>,
contributorCount: Int,
): AcdpQuery {
val acdpCharge =
when (measurementSpec.measurementTypeCase) {
// TODO(@ple13): Add support for reach-only.
MeasurementTypeCase.REACH_AND_FREQUENCY -> {
// Uses the function getLlv2AcdpCharge to compute the ACDP charge for this query as HMSS
// and LLV2 use the same approach when adding differential private noise.
val acdpChargeForReach =
AcdpParamsConverter.getLlv2AcdpCharge(
DpParams(
measurementSpec.reachAndFrequency.reachPrivacyParams.epsilon,
measurementSpec.reachAndFrequency.reachPrivacyParams.delta,
),
contributorCount,
)
val acdpChargeForFrequency =
AcdpParamsConverter.getLlv2AcdpCharge(
DpParams(
measurementSpec.reachAndFrequency.frequencyPrivacyParams.epsilon,
measurementSpec.reachAndFrequency.frequencyPrivacyParams.delta,
),
contributorCount,
)
AcdpCharge(
acdpChargeForReach.rho + acdpChargeForFrequency.rho,
acdpChargeForReach.theta + acdpChargeForFrequency.theta,
)
}
else ->
throw IllegalArgumentException(
"Measurement type ${measurementSpec.measurementTypeCase} is not supported in getHmssAcdpQuery()"
)
}

return AcdpQuery(
reference,
LandscapeMask(
eventSpecs.map { EventGroupSpec(it.filter.expression, it.collectionInterval.toRange()) },
measurementSpec.vidSamplingInterval.start,
measurementSpec.vidSamplingInterval.width,
),
acdpCharge,
)
}

/**
* Constructs a pbm specific [AcdpQuery] from given proto messages for direct measurements.
*
Expand Down
Loading
Loading