Skip to content

Commit

Permalink
Add prioritizedStages for duchy claimTask
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh committed Jun 26, 2024
1 parent 78d6488 commit dd3cb1c
Show file tree
Hide file tree
Showing 18 changed files with 344 additions and 519 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ abstract class MillBase(
) {
abstract val endingStage: ComputationStage

abstract val prioritizedStagesToClaim: List<ComputationStage>

private val meter: Meter = openTelemetry.getMeter(this::class.qualifiedName!!)

protected val cryptoWallClockDurationHistogram: DoubleHistogram =
Expand Down Expand Up @@ -158,6 +160,7 @@ abstract class MillBase(
computationType = this@MillBase.computationType
owner = millId
lockDuration = workLockDuration.toProtoDuration()
prioritizedStages += prioritizedStagesToClaim
}
val claimWorkResponse =
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ import org.wfanet.measurement.duchy.service.internal.computations.outputPathList
import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader
import org.wfanet.measurement.duchy.toProtocolStage
import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -148,10 +148,9 @@ class ReachFrequencyLiquidLegionsV2Mill(
maximumAttempts,
clock,
) {
override val endingStage: ComputationStage =
ComputationStage.newBuilder()
.apply { liquidLegionsSketchAggregationV2 = Stage.COMPLETE }
.build()
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.computationDetails
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -139,6 +140,8 @@ class ReachOnlyLiquidLegionsV2Mill(
) {
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class HonestMajorityShareShuffleMill(

override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZED.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZED, FIRST_NON_AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ interface ComputationsDatabaseTransactor<ProtocolT, StageT, StageDetailsT, Compu
* @param ownerId The identifier of the worker process that will own the lock.
* @return global computation id of work that was claimed. When null, no work was claimed.
*/
suspend fun claimTask(protocol: ProtocolT, ownerId: String, lockDuration: Duration): String?
suspend fun claimTask(
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT> = listOf(),
): String?

/**
* Transitions a computation to a new stage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ private constructor(
protocol: ComputationType,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<ComputationStage>,
): String? {
val claimed =
val editTokens =
tokens.values
.asSequence()
.filter { it.globalComputationId !in claimedComputationIds }
Expand All @@ -385,7 +386,13 @@ private constructor(
globalId = it.globalComputationId,
)
}
.firstOrNull() ?: return null
val prioritizedComputations = editTokens.filter { it.stage in prioritizedStages }
val claimed =
if (!prioritizedComputations.none()) {
prioritizedComputations.first()
} else {
editTokens.firstOrNull() ?: return null
}

updateToken(claimed) { existing ->
claimedComputationIds.add(existing.globalComputationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class PostgresComputationsService(
try {
ClaimWork(
request.computationType,
request.prioritizedStagesList,
request.owner,
lockDuration,
clock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,24 +349,49 @@ class ComputationReader(
readContext: ReadContext,
protocol: Long,
timestamp: Instant,
prioritizedStages: List<ComputationStage>,
): Flow<UnclaimedTaskQueryResult> {
val listUnclaimedTasksSql =
boundStatement(
"""
val baseSql =
"""
SELECT c.ComputationId, c.GlobalComputationId,
c.Protocol, c.ComputationStage, c.UpdateTime,
c.CreationTime, cs.NextAttempt
FROM Computations AS c
JOIN ComputationStages AS cs
ON c.ComputationId = cs.ComputationId
AND c.ComputationStage = cs.ComputationStage
WHERE c.Protocol = $1
WHERE c.Protocol = ${'$'}1
AND c.LockExpirationTime IS NOT NULL
AND c.LockExpirationTime <= $2
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
AND c.LockExpirationTime <= ${'$'}2
"""
) {

val baseSqlWithOrder =
if (prioritizedStages.isEmpty()) {
baseSql +
"""
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
} else {
// Binding list of String into the IN clause does not work as expected with r2dbc library.
// Hence, manually joining targeting stages into a comma separated string and stub it into
// the
// query.
val stagesString =
prioritizedStages
.map { computationProtocolStagesEnumHelper.computationStageEnumToLongValues(it).stage }
.toList()
.joinToString(",")
baseSql +
"""
ORDER BY CASE WHEN c.ComputationStage IN ($stagesString) THEN 0 ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
}

val listUnclaimedTasksSql =
boundStatement(baseSqlWithOrder) {
bind("$1", protocol)
bind("$2", timestamp)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader
import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.copy
Expand All @@ -48,6 +49,7 @@ import org.wfanet.measurement.internal.duchy.copy
*/
class ClaimWork<ProtocolT, StageT>(
private val protocol: ProtocolT,
private val prioritizedStages: List<ComputationStage>,
private val ownerId: String,
private val lockDuration: Duration,
private val clock: Clock,
Expand All @@ -63,6 +65,7 @@ class ClaimWork<ProtocolT, StageT>(
transactionContext,
protocolEnum,
clock.instant().truncatedTo(ChronoUnit.MICROS),
prioritizedStages,
)
// First the possible tasks to claim are selected from the computations table, then for each
// item in the list we try to claim the lock in a transaction which will only succeed if the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class GcpSpannerComputationsDatabaseTransactor<
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT>,
): String? {
/** Claim a specific task represented by the results of running the above sql. */
suspend fun claimSpecificTask(result: UnclaimedTaskQueryResult<StageT>): Boolean =
Expand All @@ -157,8 +158,11 @@ class GcpSpannerComputationsDatabaseTransactor<
lockDuration,
)
}
val prioritizedStageLongValues =
prioritizedStages.map(computationMutations::computationStageEnumToLongValues).map { it.stage }
return UnclaimedTasksQuery(
computationMutations.protocolEnumToLong(protocol),
prioritizedStageLongValues,
computationMutations::longValuesToComputationStageEnum,
clock.gcloudTimestamp(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Struct
import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues
import org.wfanet.measurement.duchy.deploy.gcloud.spanner.common.SqlBasedQuery
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.statement

/** Queries for computations which may be claimed at a timestamp. */
class UnclaimedTasksQuery<StageT>(
val protocol: Long,
protocol: Long,
prioritizedStageLongValues: List<Long>,
val parseStageEnum: (ComputationStageLongValues) -> StageT,
timestamp: Timestamp,
) : SqlBasedQuery<UnclaimedTaskQueryResult<StageT>> {
Expand All @@ -44,18 +47,27 @@ class UnclaimedTasksQuery<StageT>(
AND c.LockExpirationTime IS NOT NULL
AND c.UpdateTime IS NOT NULL
AND c.LockExpirationTime <= @current_time
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50
"""
}

override val sql: Statement =
Statement.newBuilder(parameterizedQueryString)
.bind("current_time")
.to(timestamp)
.bind("protocol")
.to(protocol)
.build()
statement(parameterizedQueryString) {
bind("current_time").to(timestamp)
bind("protocol").to(protocol)
if (prioritizedStageLongValues.isEmpty()) {
appendClause("ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC")
} else {
appendClause(
"""
ORDER BY CASE WHEN c.ComputationStage IN UNNEST (@prioritized_stages) THEN 0 ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
"""
.trimIndent()
)
bind("prioritized_stages").toInt64Array(prioritizedStageLongValues)
}
appendClause("LIMIT 50")
}

override fun asResult(struct: Struct): UnclaimedTaskQueryResult<StageT> =
UnclaimedTaskQueryResult(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ class ComputationsService(
val lockDuration =
if (request.hasLockDuration()) request.lockDuration.toDuration() else defaultLockDuration
val claimed =
computationsDatabase.claimTask(request.computationType, request.owner, lockDuration)
computationsDatabase.claimTask(
request.computationType,
request.owner,
lockDuration,
request.prioritizedStagesList,
)
return if (claimed != null) {
val token = computationsDatabase.readComputationToken(claimed)!!
sendStatusUpdateToKingdom(
Expand Down
Loading

0 comments on commit dd3cb1c

Please sign in to comment.