Skip to content

Commit

Permalink
Create Population Requisition Fulfiller (#1527)
Browse files Browse the repository at this point in the history
Create Population Requisition Fulfiller w/ tests

---------

Co-authored-by: jojijac0b <[email protected]>
  • Loading branch information
2 people authored and ple13 committed Aug 16, 2024
1 parent e29cfbf commit 4f6254c
Show file tree
Hide file tree
Showing 19 changed files with 1,680 additions and 221 deletions.
5 changes: 5 additions & 0 deletions src/main/k8s/testing/secretfiles/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ filegroup(
srcs = [
"aggregator_root.pem",
"kingdom_root.pem",
"pdp1_root.pem",
] + glob(["*edp*_root.pem"]),
)

Expand Down Expand Up @@ -216,6 +217,10 @@ SECRET_FILES = [
"exchange_workflow.textproto",
"reporting_tls.key",
"reporting_tls.pem",
"pdp1_cs_cert.der",
"pdp1_cs_private.der",
"pdp1_enc_private.tink",
"pdp1_enc_public.tink",
]

filegroup(
Expand Down
Binary file added src/main/k8s/testing/secretfiles/pdp1_cs_cert.der
Binary file not shown.
Binary file added src/main/k8s/testing/secretfiles/pdp1_cs_private.der
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 5 additions & 0 deletions src/main/k8s/testing/secretfiles/pdp1_root.key
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgkxY2b6z6khhcfMse
mahhvwEcV7iNwmcAhmIdleR7goihRANCAAQCiTgBO2Qe6kSVcdP51lDa13Q7hxoP
pDvgZa07LT26/apLhGADvKajOT6nfpeXjnUa+myjuhlP25mY24Lh/Dgq
-----END PRIVATE KEY-----
12 changes: 12 additions & 0 deletions src/main/k8s/testing/secretfiles/pdp1_root.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIB2zCCAYGgAwIBAgIUOUhTyf/lbnXnRh41LP3m8D4GpNEwCgYIKoZIzj0EAwIw
KTEVMBMGA1UECgwMSGFsbyBDTU0gRGV2MRAwDgYDVQQDDAdQZHAxIENBMB4XDTI0
MDYwNTE5MzYzMVoXDTM0MDYwMzE5MzYzMVowKTEVMBMGA1UECgwMSGFsbyBDTU0g
RGV2MRAwDgYDVQQDDAdQZHAxIENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE
Aok4ATtkHupElXHT+dZQ2td0O4caD6Q74GWtOy09uv2qS4RgA7ymozk+p36Xl451
Gvpso7oZT9uZmNuC4fw4KqOBhjCBgzAdBgNVHQ4EFgQUlQzFwajKpHfpj+5I8eFe
OMzfrbMwHwYDVR0jBBgwFoAUlQzFwajKpHfpj+5I8eFeOMzfrbMwDwYDVR0TAQH/
BAUwAwEB/zALBgNVHQ8EBAMCAYYwIwYDVR0RBBwwGoIYY2EucGRwMS5kZXYuaGFs
by1jbW0ub3JnMAoGCCqGSM49BAMCA0gAMEUCIHGO5/B9qsb+u/0s7cCoEiD7go2Z
iUJsy2LH69LJORrmAiEAmp9zPpNcE63MT0eNA3hU5fZXE34LqHdnRq+dNq9YWCo=
-----END CERTIFICATE-----
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,11 @@ object PopulationSpecValidator {
return details
}
}

/**
* Returns the size of a [VidRange] by calculating the difference between the start and end of the
* range.
*/
fun VidRange.size(): Long {
return this.endVidInclusive - this.startVid + 1
}
32 changes: 32 additions & 0 deletions src/main/kotlin/org/wfanet/measurement/dataprovider/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

package(
default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:__subpackages__",
"//src/main/kotlin/org/wfanet/measurement/populationdataprovider:__subpackages__",
],
)

kt_jvm_library(
name = "requisition_fulfiller",
srcs = ["RequisitionFulfiller.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:packed_messages",
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key",
"//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:direct_computation_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:requisitions_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:key_storage",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/throttler",
"@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/common:verification_exception",
"@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.dataprovider

import com.google.protobuf.ByteString
import com.google.protobuf.kotlin.unpack
import io.grpc.StatusException
import java.security.GeneralSecurityException
import java.security.SignatureException
import java.security.cert.CertPathValidatorException
import java.security.cert.X509Certificate
import java.util.logging.Level
import java.util.logging.Logger
import org.wfanet.measurement.api.v2alpha.Certificate
import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub
import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey
import org.wfanet.measurement.api.v2alpha.EncryptedMessage
import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey
import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt.filter
import org.wfanet.measurement.api.v2alpha.Measurement
import org.wfanet.measurement.api.v2alpha.MeasurementSpec
import org.wfanet.measurement.api.v2alpha.Requisition
import org.wfanet.measurement.api.v2alpha.RequisitionKt.refusal
import org.wfanet.measurement.api.v2alpha.RequisitionSpec
import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub
import org.wfanet.measurement.api.v2alpha.SignedMessage
import org.wfanet.measurement.api.v2alpha.fulfillDirectRequisitionRequest
import org.wfanet.measurement.api.v2alpha.getCertificateRequest
import org.wfanet.measurement.api.v2alpha.listRequisitionsRequest
import org.wfanet.measurement.api.v2alpha.refuseRequisitionRequest
import org.wfanet.measurement.api.v2alpha.unpack
import org.wfanet.measurement.common.crypto.PrivateKeyHandle
import org.wfanet.measurement.common.crypto.SigningKeyHandle
import org.wfanet.measurement.common.crypto.authorityKeyIdentifier
import org.wfanet.measurement.common.crypto.readCertificate
import org.wfanet.measurement.common.throttler.Throttler
import org.wfanet.measurement.consent.client.common.NonceMismatchException
import org.wfanet.measurement.consent.client.common.PublicKeyMismatchException
import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey
import org.wfanet.measurement.consent.client.dataprovider.decryptRequisitionSpec
import org.wfanet.measurement.consent.client.dataprovider.encryptResult
import org.wfanet.measurement.consent.client.dataprovider.signResult
import org.wfanet.measurement.consent.client.dataprovider.verifyMeasurementSpec
import org.wfanet.measurement.consent.client.dataprovider.verifyRequisitionSpec

data class DataProviderData(
/** The DataProvider's public API resource name. */
val name: String,
/** The DataProvider's display name. */
val displayName: String,
/** The DataProvider's decryption key. */
val privateEncryptionKey: PrivateKeyHandle,
/** The DataProvider's consent signaling signing key. */
val signingKeyHandle: SigningKeyHandle,
/** The CertificateKey to use for result signing. */
val certificateKey: DataProviderCertificateKey,
)

abstract class RequisitionFulfiller(
val dataProviderData: DataProviderData,
private val certificatesStub: CertificatesCoroutineStub,
private val requisitionsStub: RequisitionsCoroutineStub,
val throttler: Throttler,
private val trustedCertificates: Map<ByteString, X509Certificate>,
protected val measurementConsumerName: String,
) {
protected data class Specifications(
val measurementSpec: MeasurementSpec,
val requisitionSpec: RequisitionSpec,
)

protected class RequisitionRefusalException(
val justification: Requisition.Refusal.Justification,
message: String,
) : Exception(message)

protected class InvalidConsentSignalException(message: String? = null, cause: Throwable? = null) :
GeneralSecurityException(message, cause)

protected class InvalidSpecException(message: String, cause: Throwable? = null) :
Exception(message, cause)

/** A sequence of operations done in the simulator. */
abstract suspend fun run()

/** Executes the requisition fulfillment workflow. */
abstract suspend fun executeRequisitionFulfillingWorkflow()

protected fun verifySpecifications(
requisition: Requisition,
measurementConsumerCertificate: Certificate,
): Specifications {
val x509Certificate = readCertificate(measurementConsumerCertificate.x509Der)
// Look up the trusted issuer certificate for this MC certificate. Note that this doesn't
// confirm that this is the trusted issuer for the right MC. In a production environment,
// consider having a mapping of MC to root/CA cert.
val trustedIssuer =
trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]
?: throw InvalidConsentSignalException(
"Issuer of ${measurementConsumerCertificate.name} is not trusted"
)

try {
verifyMeasurementSpec(requisition.measurementSpec, x509Certificate, trustedIssuer)
} catch (e: CertPathValidatorException) {
throw InvalidConsentSignalException(
"Certificate path for ${measurementConsumerCertificate.name} is invalid",
e,
)
} catch (e: SignatureException) {
throw InvalidConsentSignalException("MeasurementSpec signature is invalid", e)
}

val measurementSpec: MeasurementSpec = requisition.measurementSpec.message.unpack()

val publicKey = requisition.dataProviderPublicKey.unpack(EncryptionPublicKey::class.java)!!
check(publicKey == dataProviderData.privateEncryptionKey.publicKey.toEncryptionPublicKey()) {
"Unable to decrypt for this public key"
}
val signedRequisitionSpec: SignedMessage =
try {
decryptRequisitionSpec(
requisition.encryptedRequisitionSpec,
dataProviderData.privateEncryptionKey,
)
} catch (e: GeneralSecurityException) {
throw InvalidConsentSignalException("RequisitionSpec decryption failed", e)
}
val requisitionSpec: RequisitionSpec = signedRequisitionSpec.unpack()

try {
verifyRequisitionSpec(
signedRequisitionSpec,
requisitionSpec,
measurementSpec,
x509Certificate,
trustedIssuer,
)
} catch (e: CertPathValidatorException) {
throw InvalidConsentSignalException(
"Certificate path for ${measurementConsumerCertificate.name} is invalid",
e,
)
} catch (e: SignatureException) {
throw InvalidConsentSignalException("RequisitionSpec signature is invalid", e)
} catch (e: NonceMismatchException) {
throw InvalidConsentSignalException(e.message, e)
} catch (e: PublicKeyMismatchException) {
throw InvalidConsentSignalException(e.message, e)
}

// TODO(@uakyol): Validate that collection interval is not outside of privacy landscape.

return Specifications(measurementSpec, requisitionSpec)
}

protected suspend fun getCertificate(resourceName: String): Certificate {
return try {
certificatesStub.getCertificate(getCertificateRequest { name = resourceName })
} catch (e: StatusException) {
throw Exception("Error fetching certificate $resourceName", e)
}
}

protected suspend fun refuseRequisition(
requisitionName: String,
justification: Requisition.Refusal.Justification,
message: String,
): Requisition {
try {
return requisitionsStub.refuseRequisition(
refuseRequisitionRequest {
name = requisitionName
refusal = refusal {
this.justification = justification
this.message = message
}
}
)
} catch (e: StatusException) {
throw Exception("Error refusing requisition $requisitionName", e)
}
}

protected suspend fun getRequisitions(): List<Requisition> {
val request = listRequisitionsRequest {
parent = dataProviderData.name
filter = filter {
states += Requisition.State.UNFULFILLED
measurementStates += Measurement.State.AWAITING_REQUISITION_FULFILLMENT
}
}

try {
return requisitionsStub.listRequisitions(request).requisitionsList
} catch (e: StatusException) {
throw Exception("Error listing requisitions", e)
}
}

protected suspend fun fulfillDirectMeasurement(
requisition: Requisition,
measurementSpec: MeasurementSpec,
nonce: Long,
measurementResult: Measurement.Result,
) {
logger.log(Level.INFO, "Direct MeasurementSpec:\n$measurementSpec")
logger.log(Level.INFO, "Direct MeasurementResult:\n$measurementResult")

DataProviderCertificateKey.fromName(requisition.dataProviderCertificate)
?: throw RequisitionRefusalException(
Requisition.Refusal.Justification.UNFULFILLABLE,
"Invalid data provider certificate",
)
val measurementEncryptionPublicKey: EncryptionPublicKey =
if (measurementSpec.hasMeasurementPublicKey()) {
measurementSpec.measurementPublicKey.unpack()
} else {
@Suppress("DEPRECATION") // Handle legacy resources.
EncryptionPublicKey.parseFrom(measurementSpec.serializedMeasurementPublicKey)
}
val signedResult: SignedMessage =
signResult(measurementResult, dataProviderData.signingKeyHandle)
val encryptedResult: EncryptedMessage =
encryptResult(signedResult, measurementEncryptionPublicKey)

try {
requisitionsStub.fulfillDirectRequisition(
fulfillDirectRequisitionRequest {
name = requisition.name
this.encryptedResult = encryptedResult
this.nonce = nonce
this.certificate = dataProviderData.certificateKey.toName()
}
)
} catch (e: StatusException) {
throw Exception("Error fulfilling direct requisition ${requisition.name}", e)
}
}

companion object {
val logger: Logger = Logger.getLogger(this::class.java.name)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.Synthetic
import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent
import org.wfanet.measurement.common.identity.withPrincipalName
import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler
import org.wfanet.measurement.dataprovider.DataProviderData
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.InMemoryBackingStore
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBucketFilter
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBudgetManager
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.testing.TestPrivacyBucketMapper
import org.wfanet.measurement.loadtest.dataprovider.EdpData
import org.wfanet.measurement.loadtest.dataprovider.EdpSimulator
import org.wfanet.measurement.loadtest.dataprovider.SyntheticGeneratorEventQuery
import org.wfanet.measurement.loadtest.dataprovider.VidToIndexMapGenerator
Expand Down Expand Up @@ -145,10 +145,12 @@ class InProcessEdpSimulator(

suspend fun ensureEventGroup() = delegate.ensureEventGroup(EVENT_TEMPLATES, syntheticDataSpec)

/** Builds a [EdpData] object for the Edp with a certain [displayName] and [resourceName]. */
/**
* Builds a [DataProviderData] object for the Edp with a certain [displayName] and [resourceName].
*/
@Blocking
private fun createEdpData(displayName: String, resourceName: String) =
EdpData(
DataProviderData(
name = resourceName,
displayName = displayName,
certificateKey = certificateKey,
Expand Down
Loading

0 comments on commit 4f6254c

Please sign in to comment.