Skip to content

Commit

Permalink
Add prioritizedStages for duchy claimTask (#1673)
Browse files Browse the repository at this point in the history
Issue
#1637
  • Loading branch information
renjiezh committed Jul 15, 2024
1 parent cd6952d commit 09b69b2
Show file tree
Hide file tree
Showing 19 changed files with 389 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ abstract class MillBase(
) {
abstract val endingStage: ComputationStage

abstract val prioritizedStagesToClaim: List<ComputationStage>

private val meter: Meter = openTelemetry.getMeter(this::class.qualifiedName!!)

protected val cryptoWallClockDurationHistogram: DoubleHistogram =
Expand Down Expand Up @@ -170,6 +172,7 @@ abstract class MillBase(
computationType = this@MillBase.computationType
owner = millId
lockDuration = workLockDuration.toProtoDuration()
prioritizedStages += prioritizedStagesToClaim
}
val claimWorkResponse =
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ import org.wfanet.measurement.duchy.service.internal.computations.outputPathList
import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader
import org.wfanet.measurement.duchy.toProtocolStage
import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -144,10 +144,9 @@ class ReachFrequencyLiquidLegionsV2Mill(
maximumAttempts,
clock,
) {
override val endingStage: ComputationStage =
ComputationStage.newBuilder()
.apply { liquidLegionsSketchAggregationV2 = Stage.COMPLETE }
.build()
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.computationDetails
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -134,6 +135,8 @@ class ReachOnlyLiquidLegionsV2Mill(
) {
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class HonestMajorityShareShuffleMill(

override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZED.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZED, FIRST_NON_AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,16 @@ interface ComputationsDatabaseTransactor<ProtocolT, StageT, StageDetailsT, Compu
*
* @param protocol The protocol of the task to claim
* @param ownerId The identifier of the worker process that will own the lock.
* @param lockDuration The time-to-live of the lock ownership.
* @param prioritizedStages Stages that have the priority to be claimed.
* @return global computation id of work that was claimed. When null, no work was claimed.
*/
suspend fun claimTask(protocol: ProtocolT, ownerId: String, lockDuration: Duration): String?
suspend fun claimTask(
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT> = listOf(),
): String?

/**
* Transitions a computation to a new stage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ private constructor(
protocol: ComputationType,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<ComputationStage>,
): String? {
val claimed =
val editTokens =
tokens.values
.asSequence()
.filter { it.globalComputationId !in claimedComputationIds }
Expand All @@ -385,7 +386,13 @@ private constructor(
globalId = it.globalComputationId,
)
}
.firstOrNull() ?: return null
val prioritizedComputations = editTokens.filter { it.stage in prioritizedStages }
val claimed =
if (!prioritizedComputations.none()) {
prioritizedComputations.first()
} else {
editTokens.firstOrNull() ?: return null
}

updateToken(claimed) { existing ->
claimedComputationIds.add(existing.globalComputationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class PostgresComputationsService(
try {
ClaimWork(
request.computationType,
request.prioritizedStagesList,
request.owner,
lockDuration,
clock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ kt_jvm_library(
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.wfanet.measurement.common.db.r2dbc.DatabaseClient
import org.wfanet.measurement.common.db.r2dbc.ReadContext
import org.wfanet.measurement.common.db.r2dbc.ResultRow
import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.grpc.grpcRequire
import org.wfanet.measurement.common.toProtoTime
import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper
Expand Down Expand Up @@ -349,10 +351,10 @@ class ComputationReader(
readContext: ReadContext,
protocol: Long,
timestamp: Instant,
prioritizedStages: List<ComputationStage>,
): Flow<UnclaimedTaskQueryResult> {
val listUnclaimedTasksSql =
boundStatement(
"""
val baseSql =
"""
SELECT c.ComputationId, c.GlobalComputationId,
c.Protocol, c.ComputationStage, c.UpdateTime,
c.CreationTime, cs.NextAttempt
Expand All @@ -363,15 +365,45 @@ class ComputationReader(
WHERE c.Protocol = $1
AND c.LockExpirationTime IS NOT NULL
AND c.LockExpirationTime <= $2
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
) {
bind("$1", protocol)
bind("$2", timestamp)
}

return readContext.executeQuery(listUnclaimedTasksSql).consume(::buildUnclaimedTaskQueryResult)
if (prioritizedStages.isEmpty()) {
val baseSqlWithOrder =
baseSql +
"""
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
val statement =
boundStatement(baseSqlWithOrder) {
bind("$1", protocol)
bind("$2", timestamp)
}

return readContext.executeQuery(statement).consume(::buildUnclaimedTaskQueryResult)
} else {
val baseSqlWithOrder =
baseSql +
"""
ORDER BY
CASE WHEN c.ComputationStage
IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) THEN 0
ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
val statement =
valuesListBoundStatement(valuesStartIndex = 2, paramCount = 1, baseSqlWithOrder) {
bind("$1", protocol)
bind("$2", timestamp)
for (stage in prioritizedStages) {
val longValue =
computationProtocolStagesEnumHelper.computationStageEnumToLongValues(stage).stage
addValuesBinding { bindValuesParam(0, longValue) }
}
}
return readContext.executeQuery(statement).consume(::buildUnclaimedTaskQueryResult)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader
import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.copy
Expand All @@ -48,6 +49,7 @@ import org.wfanet.measurement.internal.duchy.copy
*/
class ClaimWork<ProtocolT, StageT>(
private val protocol: ProtocolT,
private val prioritizedStages: List<ComputationStage>,
private val ownerId: String,
private val lockDuration: Duration,
private val clock: Clock,
Expand All @@ -63,6 +65,7 @@ class ClaimWork<ProtocolT, StageT>(
transactionContext,
protocolEnum,
clock.instant().truncatedTo(ChronoUnit.MICROS),
prioritizedStages,
)
// First the possible tasks to claim are selected from the computations table, then for each
// item in the list we try to claim the lock in a transaction which will only succeed if the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class GcpSpannerComputationsDatabaseTransactor<
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT>,
): String? {
/** Claim a specific task represented by the results of running the above sql. */
suspend fun claimSpecificTask(result: UnclaimedTaskQueryResult<StageT>): Boolean =
Expand All @@ -159,7 +160,9 @@ class GcpSpannerComputationsDatabaseTransactor<
}
return UnclaimedTasksQuery(
computationMutations.protocolEnumToLong(protocol),
prioritizedStages,
computationMutations::longValuesToComputationStageEnum,
computationMutations::computationStageEnumToLongValues,
clock.gcloudTimestamp(),
)
.execute(databaseClient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Struct
import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues
import org.wfanet.measurement.duchy.deploy.gcloud.spanner.common.SqlBasedQuery
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.statement

/** Queries for computations which may be claimed at a timestamp. */
class UnclaimedTasksQuery<StageT>(
val protocol: Long,
val parseStageEnum: (ComputationStageLongValues) -> StageT,
protocol: Long,
prioritizedStages: List<StageT>,
val longValuesToComputationStageEnum: (ComputationStageLongValues) -> StageT,
computationStageEnumToLongValues: (StageT) -> ComputationStageLongValues,
timestamp: Timestamp,
) : SqlBasedQuery<UnclaimedTaskQueryResult<StageT>> {
companion object {
Expand All @@ -44,25 +48,38 @@ class UnclaimedTasksQuery<StageT>(
AND c.LockExpirationTime IS NOT NULL
AND c.UpdateTime IS NOT NULL
AND c.LockExpirationTime <= @current_time
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50
"""
}

override val sql: Statement =
Statement.newBuilder(parameterizedQueryString)
.bind("current_time")
.to(timestamp)
.bind("protocol")
.to(protocol)
.build()
statement(parameterizedQueryString) {
bind("current_time").to(timestamp)
bind("protocol").to(protocol)
if (prioritizedStages.isEmpty()) {
appendClause("ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC")
} else {
appendClause(
"""
ORDER BY
CASE WHEN c.ComputationStage IN UNNEST (@prioritized_stages) THEN 0
ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
"""
.trimIndent()
)
val prioritizedStageLongValues =
prioritizedStages.map(computationStageEnumToLongValues).map { it.stage }
bind("prioritized_stages").toInt64Array(prioritizedStageLongValues)
}
appendClause("LIMIT 50")
}

override fun asResult(struct: Struct): UnclaimedTaskQueryResult<StageT> =
UnclaimedTaskQueryResult(
computationId = struct.getLong("ComputationId"),
globalId = struct.getString("GlobalComputationId"),
computationStage =
parseStageEnum(
longValuesToComputationStageEnum(
ComputationStageLongValues(struct.getLong("Protocol"), struct.getLong("ComputationStage"))
),
creationTime = struct.getTimestamp("CreationTime"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ class ComputationsService(
val lockDuration =
if (request.hasLockDuration()) request.lockDuration.toDuration() else defaultLockDuration
val claimed =
computationsDatabase.claimTask(request.computationType, request.owner, lockDuration)
computationsDatabase.claimTask(
request.computationType,
request.owner,
lockDuration,
request.prioritizedStagesList,
)
return if (claimed != null) {
val token = computationsDatabase.readComputationToken(claimed)!!
sendStatusUpdateToKingdom(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,62 @@ abstract class ComputationsServiceTest<T : ComputationsCoroutineImplBase> {
assertThat(claimWorkResponse).isEqualTo(ClaimWorkResponse.getDefaultInstance())
}

@Test
fun `claimWork returns computation of prioritized stage`() = runBlocking {
var computation1 =
service
.createComputation(
DEFAULT_CREATE_COMPUTATION_REQUEST.copy {
globalComputationId = "1"
requisitions.clear()
}
)
.token

computation1 =
service
.claimWork(
DEFAULT_CLAIM_WORK_REQUEST.copy {
this.lockDuration = Durations.fromSeconds(1)
prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage()
}
)
.token

computation1 =
service
.advanceComputationStage(
advanceComputationStageRequest {
token = computation1
nextComputationStage = Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage()
// To simplify the test, assume WAIT_REQUISITIONS_AND_KEY_SET is claimable.
afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE
}
)
.token

// create a new computation with prioritized stage.
val computation2 =
service
.createComputation(
DEFAULT_CREATE_COMPUTATION_REQUEST.copy {
globalComputationId = "2"
requisitions.clear()
}
)
.token

val response =
service.claimWork(
DEFAULT_CLAIM_WORK_REQUEST.copy {
this.lockDuration = Durations.fromSeconds(1)
prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage()
}
)

assertThat(response.token.globalComputationId).isEqualTo(computation2.globalComputationId)
}

@Test
fun `getComputationIds returns empty response when no matching computations`() = runBlocking {
service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST)
Expand Down
Loading

0 comments on commit 09b69b2

Please sign in to comment.