diff --git a/coordinator/conf/config.json b/coordinator/conf/config.json index de8944a7b9..fb738926c6 100644 --- a/coordinator/conf/config.json +++ b/coordinator/conf/config.json @@ -2,6 +2,7 @@ "prover_manager": { "provers_per_session": 1, "session_attempts": 5, + "external_prover_threshold": 32, "bundle_collection_time_sec": 180, "batch_collection_time_sec": 180, "chunk_collection_time_sec": 180, diff --git a/coordinator/internal/config/config.go b/coordinator/internal/config/config.go index dbdaa40b02..52a9158d0b 100644 --- a/coordinator/internal/config/config.go +++ b/coordinator/internal/config/config.go @@ -16,6 +16,8 @@ type ProverManager struct { // Number of attempts that a session can be retried if previous attempts failed. // Currently we only consider proving timeout as failure here. SessionAttempts uint8 `json:"session_attempts"` + // Threshold for activating the external prover based on unassigned task count. + ExternalProverThreshold int64 `json:"external_prover_threshold"` // Zk verifier config. Verifier *VerifierConfig `json:"verifier"` // BatchCollectionTimeSec batch Proof collection time (in seconds). 
diff --git a/coordinator/internal/logic/provertask/batch_prover_task.go b/coordinator/internal/logic/provertask/batch_prover_task.go index 7a472c4baf..08fa468fb0 100644 --- a/coordinator/internal/logic/provertask/batch_prover_task.go +++ b/coordinator/internal/logic/provertask/batch_prover_task.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/gin-gonic/gin" @@ -63,29 +64,59 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts + if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) { + unassignedBatchCount, getCountError := bp.batchOrm.GetUnassignedBatchCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + if getCountError != nil { + log.Error("failed to get unassigned batch proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError) + return nil, ErrCoordinatorInternalFailure + } + // Assign external prover if unassigned task number exceeds threshold + if unassignedBatchCount < bp.cfg.ProverManager.ExternalProverThreshold { + return nil, nil + } + } + var batchTask *orm.Batch for i := 0; i < 5; i++ { var getTaskError error var tmpBatchTask *orm.Batch - tmpBatchTask, getTaskError = bp.batchOrm.GetAssignedBatch(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + var assignedOffset, unassignedOffset = 0, 0 + tmpAssignedBatchTasks, getTaskError := bp.batchOrm.GetAssignedBatches(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50) if getTaskError != nil { log.Error("failed to get assigned batch proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) return nil, ErrCoordinatorInternalFailure } - // Why here need get again? In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned` 
- if tmpBatchTask == nil { - tmpBatchTask, getTaskError = bp.batchOrm.GetUnassignedBatch(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + // batch to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql. + tmpUnassignedBatchTask, getTaskError := bp.batchOrm.GetUnassignedBatches(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50) + if getTaskError != nil { + log.Error("failed to get unassigned batch proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + return nil, ErrCoordinatorInternalFailure + } + for { + tmpBatchTask = nil + if assignedOffset < len(tmpAssignedBatchTasks) { + tmpBatchTask = tmpAssignedBatchTasks[assignedOffset] + assignedOffset++ + } else if unassignedOffset < len(tmpUnassignedBatchTask) { + tmpBatchTask = tmpUnassignedBatchTask[unassignedOffset] + unassignedOffset++ + } + + if tmpBatchTask == nil { + log.Debug("get empty batch", "height", getTaskParameter.ProverHeight) + return nil, nil + } + + // Don't dispatch the same failing job to the same prover + proverTask, getTaskError := bp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeBatch, tmpBatchTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion) if getTaskError != nil { - log.Error("failed to get unassigned batch proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeBatch.String(), "taskID", tmpBatchTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", getTaskError) return nil, ErrCoordinatorInternalFailure } - } - - if tmpBatchTask == nil { - log.Debug("get empty batch", "height", getTaskParameter.ProverHeight) - return nil, nil + if proverTask == nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid { + break + } } rowsAffected, updateAttemptsErr := bp.batchOrm.UpdateBatchAttempts(ctx.Copy(), tmpBatchTask.Index, tmpBatchTask.ActiveAttempts, 
tmpBatchTask.TotalAttempts) diff --git a/coordinator/internal/logic/provertask/bundle_prover_task.go b/coordinator/internal/logic/provertask/bundle_prover_task.go index c8901d2d4e..a13823e10d 100644 --- a/coordinator/internal/logic/provertask/bundle_prover_task.go +++ b/coordinator/internal/logic/provertask/bundle_prover_task.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/gin-gonic/gin" @@ -63,29 +64,59 @@ func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinat maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts + if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) { + unassignedBundleCount, getCountError := bp.bundleOrm.GetUnassignedBundleCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + if getCountError != nil { + log.Error("failed to get unassigned bundle proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError) + return nil, ErrCoordinatorInternalFailure + } + // Assign external prover if unassigned task number exceeds threshold + if unassignedBundleCount < bp.cfg.ProverManager.ExternalProverThreshold { + return nil, nil + } + } + var bundleTask *orm.Bundle for i := 0; i < 5; i++ { var getTaskError error var tmpBundleTask *orm.Bundle - tmpBundleTask, getTaskError = bp.bundleOrm.GetAssignedBundle(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + var assignedOffset, unassignedOffset = 0, 0 + tmpAssignedBundleTasks, getTaskError := bp.bundleOrm.GetAssignedBundles(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50) if getTaskError != nil { - log.Error("failed to get assigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + log.Error("failed to get assigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) return nil, ErrCoordinatorInternalFailure } - // Why here need get again? 
In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned` - // bundle to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql. - if tmpBundleTask == nil { - tmpBundleTask, getTaskError = bp.bundleOrm.GetUnassignedBundle(ctx.Copy(), maxActiveAttempts, maxTotalAttempts) + // bundle to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql. + tmpUnassignedBundleTask, getTaskError := bp.bundleOrm.GetUnassignedBundles(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50) + if getTaskError != nil { + log.Error("failed to get unassigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + return nil, ErrCoordinatorInternalFailure + } + for { + tmpBundleTask = nil + if assignedOffset < len(tmpAssignedBundleTasks) { + tmpBundleTask = tmpAssignedBundleTasks[assignedOffset] + assignedOffset++ + } else if unassignedOffset < len(tmpUnassignedBundleTask) { + tmpBundleTask = tmpUnassignedBundleTask[unassignedOffset] + unassignedOffset++ + } + + if tmpBundleTask == nil { + log.Debug("get empty bundle", "height", getTaskParameter.ProverHeight) + return nil, nil + } + + // Don't dispatch the same failing job to the same prover + proverTask, getTaskError := bp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeBundle, tmpBundleTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion) if getTaskError != nil { - log.Error("failed to get unassigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeBundle.String(), "taskID", tmpBundleTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", getTaskError) return nil, ErrCoordinatorInternalFailure } - } - - if tmpBundleTask == nil { - log.Debug("get empty bundle", "height", getTaskParameter.ProverHeight) - return nil, nil + if proverTask == 
nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid { + break + } } rowsAffected, updateAttemptsErr := bp.bundleOrm.UpdateBundleAttempts(ctx.Copy(), tmpBundleTask.Hash, tmpBundleTask.ActiveAttempts, tmpBundleTask.TotalAttempts) diff --git a/coordinator/internal/logic/provertask/chunk_prover_task.go b/coordinator/internal/logic/provertask/chunk_prover_task.go index 56e82a91d3..5aa396e2f2 100644 --- a/coordinator/internal/logic/provertask/chunk_prover_task.go +++ b/coordinator/internal/logic/provertask/chunk_prover_task.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/gin-gonic/gin" @@ -61,29 +62,59 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession maxTotalAttempts := cp.cfg.ProverManager.SessionAttempts + if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) { + unassignedChunkCount, getCountError := cp.chunkOrm.GetUnassignedChunkCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight) + if getCountError != nil { + log.Error("failed to get unassigned chunk proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError) + return nil, ErrCoordinatorInternalFailure + } + // Assign external prover if unassigned task number exceeds threshold + if unassignedChunkCount < cp.cfg.ProverManager.ExternalProverThreshold { + return nil, nil + } + } + var chunkTask *orm.Chunk for i := 0; i < 5; i++ { var getTaskError error var tmpChunkTask *orm.Chunk - tmpChunkTask, getTaskError = cp.chunkOrm.GetAssignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight) + var assignedOffset, unassignedOffset = 0, 0 + tmpAssignedChunkTasks, getTaskError := cp.chunkOrm.GetAssignedChunks(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight, 50) if getTaskError != nil { log.Error("failed to get assigned chunk proving 
tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) return nil, ErrCoordinatorInternalFailure } - // Why here need get again? In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned` // chunk to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql. - if tmpChunkTask == nil { - tmpChunkTask, getTaskError = cp.chunkOrm.GetUnassignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight) + tmpUnassignedChunkTask, getTaskError := cp.chunkOrm.GetUnassignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight, 50) + if getTaskError != nil { + log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + return nil, ErrCoordinatorInternalFailure + } + for { + tmpChunkTask = nil + if assignedOffset < len(tmpAssignedChunkTasks) { + tmpChunkTask = tmpAssignedChunkTasks[assignedOffset] + assignedOffset++ + } else if unassignedOffset < len(tmpUnassignedChunkTask) { + tmpChunkTask = tmpUnassignedChunkTask[unassignedOffset] + unassignedOffset++ + } + + if tmpChunkTask == nil { + log.Debug("get empty chunk", "height", getTaskParameter.ProverHeight) + return nil, nil + } + + // Don't dispatch the same failing job to the same prover + proverTask, getTaskError := cp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeChunk, tmpChunkTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion) if getTaskError != nil { - log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError) + log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeChunk.String(), "taskID", tmpChunkTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", getTaskError) return nil, ErrCoordinatorInternalFailure } - } - - if tmpChunkTask == nil { - log.Debug("get empty chunk", "height", 
getTaskParameter.ProverHeight) - return nil, nil + if proverTask == nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid { + break + } } rowsAffected, updateAttemptsErr := cp.chunkOrm.UpdateChunkAttempts(ctx.Copy(), tmpChunkTask.Index, tmpChunkTask.ActiveAttempts, tmpChunkTask.TotalAttempts) diff --git a/coordinator/internal/logic/provertask/prover_task.go b/coordinator/internal/logic/provertask/prover_task.go index 14568e120e..507f3cad0d 100644 --- a/coordinator/internal/logic/provertask/prover_task.go +++ b/coordinator/internal/logic/provertask/prover_task.go @@ -27,6 +27,11 @@ var ( getTaskCounterVec *prometheus.CounterVec = nil ) +var ( + // ExternalProverNamePrefix prefix of prover name + ExternalProverNamePrefix = "external" +) + // ProverTask the interface of a collector who send data to prover type ProverTask interface { Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) diff --git a/coordinator/internal/orm/batch.go b/coordinator/internal/orm/batch.go index 3dd8412e58..b3907c9ec8 100644 --- a/coordinator/internal/orm/batch.go +++ b/coordinator/internal/orm/batch.go @@ -78,38 +78,48 @@ func (*Batch) TableName() string { return "batch" } -// GetUnassignedBatch retrieves unassigned batch based on the specified limit. +// GetUnassignedBatches retrieves unassigned batches based on the specified limit. // The returned batches are sorted in ascending order by their index. 
-func (o *Batch) GetUnassignedBatch(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Batch, error) { - var batch Batch +func (o *Batch) GetUnassignedBatches(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Batch, error) { + var batch []*Batch db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT 1;", - int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady)) + sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT %d;", + int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady), limit) err := db.Raw(sql).Scan(&batch).Error if err != nil { - return nil, fmt.Errorf("Batch.GetUnassignedBatch error: %w", err) + return nil, fmt.Errorf("Batch.GetUnassignedBatches error: %w", err) } - if batch.Hash == "" { - return nil, nil + return batch, nil +} + +// GetUnassignedBatchCount retrieves unassigned batch count based on the specified limit. 
+func (o *Batch) GetUnassignedBatchCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (int64, error) { + var count int64 + db := o.db.WithContext(ctx) + db = db.Model(&Batch{}) + db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned)) + db = db.Where("total_attempts < ?", maxTotalAttempts) + db = db.Where("active_attempts < ?", maxActiveAttempts) + db = db.Where("chunk_proofs_status = ?", int(types.ChunkProofsStatusReady)) + db = db.Where("batch.deleted_at IS NULL") + if err := db.Count(&count).Error; err != nil { + return 0, fmt.Errorf("Batch.GetUnassignedBatchCount error: %w", err) } - return &batch, nil + return count, nil } -// GetAssignedBatch retrieves assigned batch based on the specified limit. +// GetAssignedBatches retrieves assigned batches based on the specified limit. // The returned batches are sorted in ascending order by their index. -func (o *Batch) GetAssignedBatch(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Batch, error) { - var batch Batch +func (o *Batch) GetAssignedBatches(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Batch, error) { + var batch []*Batch db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT 1;", - int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady)) + sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT %d;", + int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady), limit) err := db.Raw(sql).Scan(&batch).Error if err != nil { - return nil, fmt.Errorf("Batch.GetAssignedBatch error: %w", err) - } - if batch.Hash == "" { - 
return nil, nil + return nil, fmt.Errorf("Batch.GetAssignedBatches error: %w", err) } - return &batch, nil + return batch, nil } // GetUnassignedAndChunksUnreadyBatches get the batches which is unassigned and chunks is not ready @@ -132,19 +142,6 @@ func (o *Batch) GetUnassignedAndChunksUnreadyBatches(ctx context.Context, offset return batches, nil } -// GetAssignedBatches retrieves all batches whose proving_status is either types.ProvingTaskAssigned. -func (o *Batch) GetAssignedBatches(ctx context.Context) ([]*Batch, error) { - db := o.db.WithContext(ctx) - db = db.Model(&Batch{}) - db = db.Where("proving_status = ?", int(types.ProvingTaskAssigned)) - - var assignedBatches []*Batch - if err := db.Find(&assignedBatches).Error; err != nil { - return nil, fmt.Errorf("Batch.GetAssignedBatches error: %w", err) - } - return assignedBatches, nil -} - // GetProvingStatusByHash retrieves the proving status of a batch given its hash. func (o *Batch) GetProvingStatusByHash(ctx context.Context, hash string) (types.ProvingStatus, error) { db := o.db.WithContext(ctx) diff --git a/coordinator/internal/orm/bundle.go b/coordinator/internal/orm/bundle.go index 5deeff1114..0bf6efa5b8 100644 --- a/coordinator/internal/orm/bundle.go +++ b/coordinator/internal/orm/bundle.go @@ -54,38 +54,47 @@ func (*Bundle) TableName() string { return "bundle" } -// GetUnassignedBundle retrieves unassigned bundle based on the specified limit. +// GetUnassignedBundles retrieves unassigned bundle based on the specified limit. // The returned batch sorts in ascending order by their index. 
-func (o *Bundle) GetUnassignedBundle(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Bundle, error) { - var bundle Bundle +func (o *Bundle) GetUnassignedBundles(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Bundle, error) { + var bundle []*Bundle db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM bundle WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND batch_proofs_status = %d AND bundle.deleted_at IS NULL ORDER BY bundle.index LIMIT 1;", - int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.BatchProofsStatusReady)) + sql := fmt.Sprintf("SELECT * FROM bundle WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND batch_proofs_status = %d AND bundle.deleted_at IS NULL ORDER BY bundle.index LIMIT %d;", + int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.BatchProofsStatusReady), limit) err := db.Raw(sql).Scan(&bundle).Error if err != nil { - return nil, fmt.Errorf("Batch.GetUnassignedBundle error: %w", err) - } - if bundle.StartBatchHash == "" || bundle.EndBatchHash == "" { - return nil, nil + return nil, fmt.Errorf("Bundle.GetUnassignedBundles error: %w", err) } - return &bundle, nil + return bundle, nil +} + +// GetUnassignedBundleCount retrieves unassigned bundle count based on the specified limit. 
+func (o *Bundle) GetUnassignedBundleCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (int64, error) { + var count int64 + db := o.db.WithContext(ctx) + db = db.Model(&Bundle{}) + db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned)) + db = db.Where("total_attempts < ?", maxTotalAttempts) + db = db.Where("active_attempts < ?", maxActiveAttempts) + db = db.Where("bundle.deleted_at IS NULL") + if err := db.Count(&count).Error; err != nil { + return 0, fmt.Errorf("Bundle.GetUnassignedBundleCount error: %w", err) + } + return count, nil } -// GetAssignedBundle retrieves assigned bundle based on the specified limit. +// GetAssignedBundles retrieves assigned bundles based on the specified limit. // The returned bundle sorts in ascending order by their index. -func (o *Bundle) GetAssignedBundle(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Bundle, error) { - var bundle Bundle +func (o *Bundle) GetAssignedBundles(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Bundle, error) { + var bundle []*Bundle db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM bundle WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND batch_proofs_status = %d AND bundle.deleted_at IS NULL ORDER BY bundle.index LIMIT 1;", - int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.BatchProofsStatusReady)) + sql := fmt.Sprintf("SELECT * FROM bundle WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND batch_proofs_status = %d AND bundle.deleted_at IS NULL ORDER BY bundle.index LIMIT %d;", + int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.BatchProofsStatusReady), limit) err := db.Raw(sql).Scan(&bundle).Error if err != nil { - return nil, fmt.Errorf("Bundle.GetAssignedBatch error: %w", err) + return nil, fmt.Errorf("Bundle.GetAssignedBundles error: %w", err) } - if bundle.StartBatchHash == "" || 
bundle.EndBatchHash == "" { - return nil, nil - } - return &bundle, nil + return bundle, nil } // GetProvingStatusByHash retrieves the proving status of a bundle given its hash. diff --git a/coordinator/internal/orm/chunk.go b/coordinator/internal/orm/chunk.go index 3f1964c400..ce73f3cbb9 100644 --- a/coordinator/internal/orm/chunk.go +++ b/coordinator/internal/orm/chunk.go @@ -73,36 +73,46 @@ func (*Chunk) TableName() string { // GetUnassignedChunk retrieves unassigned chunk based on the specified limit. // The returned chunks are sorted in ascending order by their index. -func (o *Chunk) GetUnassignedChunk(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64) (*Chunk, error) { - var chunk Chunk +func (o *Chunk) GetUnassignedChunk(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height, limit uint64) ([]*Chunk, error) { + var chunks []*Chunk db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM chunk WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND end_block_number <= %d AND chunk.deleted_at IS NULL ORDER BY chunk.index LIMIT 1;", - int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, height) - err := db.Raw(sql).Scan(&chunk).Error + sql := fmt.Sprintf("SELECT * FROM chunk WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND end_block_number <= %d AND chunk.deleted_at IS NULL ORDER BY chunk.index LIMIT %d;", + int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, height, limit) + err := db.Raw(sql).Scan(&chunks).Error if err != nil { return nil, fmt.Errorf("Chunk.GetUnassignedChunk error: %w", err) } - if chunk.Hash == "" { - return nil, nil + return chunks, nil +} + +// GetUnassignedChunkCount retrieves unassigned chunk count based on the specified limit. 
+func (o *Chunk) GetUnassignedChunkCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64) (int64, error) { + var count int64 + db := o.db.WithContext(ctx) + db = db.Model(&Chunk{}) + db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned)) + db = db.Where("total_attempts < ?", maxTotalAttempts) + db = db.Where("active_attempts < ?", maxActiveAttempts) + db = db.Where("end_block_number <= ?", height) + db = db.Where("chunk.deleted_at IS NULL") + if err := db.Count(&count).Error; err != nil { + return 0, fmt.Errorf("Chunk.GetUnassignedChunkCount error: %w", err) } - return &chunk, nil + return count, nil } -// GetAssignedChunk retrieves assigned chunk based on the specified limit. +// GetAssignedChunks retrieves assigned chunks based on the specified limit. // The returned chunks are sorted in ascending order by their index. -func (o *Chunk) GetAssignedChunk(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64) (*Chunk, error) { - var chunk Chunk +func (o *Chunk) GetAssignedChunks(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64, limit uint64) ([]*Chunk, error) { + var chunks []*Chunk db := o.db.WithContext(ctx) - sql := fmt.Sprintf("SELECT * FROM chunk WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND end_block_number <= %d AND chunk.deleted_at IS NULL ORDER BY chunk.index LIMIT 1;", - int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, height) - err := db.Raw(sql).Scan(&chunk).Error + sql := fmt.Sprintf("SELECT * FROM chunk WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND end_block_number <= %d AND chunk.deleted_at IS NULL ORDER BY chunk.index LIMIT %d;", + int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, height, limit) + err := db.Raw(sql).Scan(&chunks).Error if err != nil { - return nil, fmt.Errorf("Chunk.GetAssignedChunk error: %w", err) + return nil, 
fmt.Errorf("Chunk.GetAssignedChunks error: %w", err) } - if chunk.Hash == "" { - return nil, nil - } - return &chunk, nil + return chunks, nil } // GetChunksByBatchHash retrieves the chunks associated with a specific batch hash. diff --git a/coordinator/internal/orm/prover_task.go b/coordinator/internal/orm/prover_task.go index 00d8b36167..53aea7a383 100644 --- a/coordinator/internal/orm/prover_task.go +++ b/coordinator/internal/orm/prover_task.go @@ -148,6 +148,24 @@ func (o *ProverTask) GetAssignedTaskOfOtherProvers(ctx context.Context, taskType return proverTasks, nil } +// GetTaskOfProver get the chunk/batch task of prover +func (o *ProverTask) GetTaskOfProver(ctx context.Context, taskType message.ProofType, taskID, proverPublicKey, proverVersion string) (*ProverTask, error) { + db := o.db.WithContext(ctx) + db = db.Model(&ProverTask{}) + db = db.Where("task_type", int(taskType)) + db = db.Where("task_id", taskID) + db = db.Where("prover_public_key", proverPublicKey) + db = db.Where("prover_version", proverVersion) + db = db.Limit(1) + + var proverTask ProverTask + err := db.Find(&proverTask).Error + if err != nil { + return nil, fmt.Errorf("ProverTask.GetTaskOfProver error: %w, taskID: %v, publicKey:%s", err, taskID, proverPublicKey) + } + return &proverTask, nil +} + // GetProvingStatusByTaskID retrieves the proving status of a prover task func (o *ProverTask) GetProvingStatusByTaskID(ctx context.Context, taskType message.ProofType, taskID string) (types.ProverProveStatus, error) { db := o.db.WithContext(ctx) diff --git a/prover/config.json b/prover/config.json index 0a816360d5..7247ce49b1 100644 --- a/prover/config.json +++ b/prover/config.json @@ -3,7 +3,7 @@ "keystore_path": "keystore.json", "keystore_password": "prover-pwd", "db_path": "unique-db-path-for-prover-1", - "prover_type": 2, + "prover_types": [2], "low_version_circuit": { "hard_fork_name": "bernoulli", "params_path": "params", diff --git a/prover/src/config.rs b/prover/src/config.rs index 
4e3c1f2ccc..05a57ddae2 100644 --- a/prover/src/config.rs +++ b/prover/src/config.rs @@ -30,7 +30,7 @@ pub struct Config { pub keystore_path: String, pub keystore_password: String, pub db_path: String, - pub prover_type: ProverType, + pub prover_types: Vec, pub low_version_circuit: CircuitConfig, pub high_version_circuit: CircuitConfig, pub coordinator: CoordinatorConfig, diff --git a/prover/src/coordinator_client.rs b/prover/src/coordinator_client.rs index 46067d7ccf..43ea4f8dfb 100644 --- a/prover/src/coordinator_client.rs +++ b/prover/src/coordinator_client.rs @@ -73,7 +73,7 @@ impl<'a> CoordinatorClient<'a> { challenge: token.clone(), prover_name: self.config.prover_name.clone(), prover_version: crate::version::get_version(), - prover_types: vec![self.config.prover_type], + prover_types: self.config.prover_types.clone(), vks: self.vks.clone(), }; diff --git a/prover/src/main.rs b/prover/src/main.rs index 75553187a9..aafdd36add 100644 --- a/prover/src/main.rs +++ b/prover/src/main.rs @@ -66,7 +66,7 @@ fn start() -> Result<()> { log::info!( "prover start successfully. 
name: {}, type: {:?}, publickey: {}, version: {}", config.prover_name, - config.prover_type, + config.prover_types, prover.get_public_key(), version::get_version(), ); diff --git a/prover/src/prover.rs b/prover/src/prover.rs index 7de83906e0..4c190cb190 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -8,8 +8,8 @@ use crate::{ coordinator_client::{listener::Listener, types::*, CoordinatorClient}, geth_client::GethClient, key_signer::KeySigner, - types::{ProofFailureType, ProofStatus, ProverType}, - utils::get_task_types, + types::{ProofFailureType, ProofStatus, ProverType, TaskType}, + utils::{get_prover_type, get_task_types}, zk_circuits_handler::{CircuitsHandler, CircuitsHandlerProvider}, }; @@ -25,11 +25,14 @@ pub struct Prover<'a> { impl<'a> Prover<'a> { pub fn new(config: &'a Config, coordinator_listener: Box) -> Result { - let prover_type = config.prover_type; let keystore_path = &config.keystore_path; let keystore_password = &config.keystore_password; - let geth_client = if config.prover_type == ProverType::Chunk { + let geth_client = if config + .prover_types + .iter() + .any(|element| *element == ProverType::Chunk) + { Some(Rc::new(RefCell::new( GethClient::new( &config.prover_name, @@ -41,10 +44,10 @@ impl<'a> Prover<'a> { None }; - let provider = CircuitsHandlerProvider::new(prover_type, config, geth_client.clone()) + let provider = CircuitsHandlerProvider::new(config, geth_client.clone()) .context("failed to create circuits handler provider")?; - let vks = provider.init_vks(prover_type, config, geth_client.clone()); + let vks = provider.init_vks(config.prover_types.clone(), config, geth_client.clone()); let key_signer = Rc::new(KeySigner::new(keystore_path, keystore_password)?); let coordinator_client = @@ -68,12 +71,27 @@ impl<'a> Prover<'a> { pub fn fetch_task(&self) -> Result { log::info!("[prover] start to fetch_task"); + + let task_types: Vec = + self.config + .prover_types + .iter() + .fold(Vec::new(), |mut acc, prover_type| { + 
acc.extend(get_task_types(*prover_type)); + acc + }); + let mut req = GetTaskRequest { - task_types: get_task_types(self.config.prover_type), + task_types, prover_height: None, }; - if self.config.prover_type == ProverType::Chunk { + if self + .config + .prover_types + .iter() + .any(|element| *element == ProverType::Chunk) + { let latest_block_number = self.get_latest_block_number_value()?; if let Some(v) = latest_block_number { if v.as_u64() == 0 { @@ -96,11 +114,17 @@ impl<'a> Prover<'a> { } pub fn prove_task(&self, task: &Task) -> Result { + let prover_type = match get_prover_type(task.task_type) { + Some(pt) => Ok(pt), + None => { + bail!("unsupported prover_type.") + } + }?; log::info!("[prover] start to prove_task, task id: {}", task.id); let handler: Rc> = self .circuits_handler_provider .borrow_mut() - .get_circuits_handler(&task.hard_fork_name) + .get_circuits_handler(&task.hard_fork_name, prover_type) .context("failed to get circuit handler")?; self.do_prove(task, handler) } diff --git a/prover/src/utils.rs b/prover/src/utils.rs index 18be4ac7a1..0347554adc 100644 --- a/prover/src/utils.rs +++ b/prover/src/utils.rs @@ -24,9 +24,31 @@ pub fn log_init(log_file: Option) { }); } +// pub fn get_task_types(prover_types: Vec) -> Vec { +// prover_types.into_iter().fold(Vec::new(), |mut acc, prover_type| { +// match prover_type { +// ProverType::Chunk => acc.push(TaskType::Chunk), +// ProverType::Batch => { +// acc.push(TaskType::Batch); +// acc.push(TaskType::Bundle); +// } +// } +// acc +// }) +// } + pub fn get_task_types(prover_type: ProverType) -> Vec { match prover_type { ProverType::Chunk => vec![TaskType::Chunk], ProverType::Batch => vec![TaskType::Batch, TaskType::Bundle], } } + +pub fn get_prover_type(task_type: TaskType) -> Option { + match task_type { + TaskType::Undefined => None, + TaskType::Chunk => Some(ProverType::Chunk), + TaskType::Batch => Some(ProverType::Batch), + TaskType::Bundle => Some(ProverType::Batch), + } +} diff --git 
a/prover/src/zk_circuits_handler.rs b/prover/src/zk_circuits_handler.rs index d1a8eb38c5..8956b0019c 100644 --- a/prover/src/zk_circuits_handler.rs +++ b/prover/src/zk_circuits_handler.rs @@ -34,21 +34,17 @@ type CircuitsHandlerBuilder = fn( ) -> Result>; pub struct CircuitsHandlerProvider<'a> { - prover_type: ProverType, config: &'a Config, geth_client: Option>>, circuits_handler_builder_map: HashMap, current_fork_name: Option, + current_prover_type: Option, current_circuit: Option>>, } impl<'a> CircuitsHandlerProvider<'a> { - pub fn new( - prover_type: ProverType, - config: &'a Config, - geth_client: Option>>, - ) -> Result { + pub fn new(config: &'a Config, geth_client: Option>>) -> Result { let mut m: HashMap = HashMap::new(); fn handler_builder( @@ -99,11 +95,11 @@ impl<'a> CircuitsHandlerProvider<'a> { ); let provider = CircuitsHandlerProvider { - prover_type, config, geth_client, circuits_handler_builder_map: m, current_fork_name: None, + current_prover_type: None, current_circuit: None, }; @@ -113,6 +109,7 @@ impl<'a> CircuitsHandlerProvider<'a> { pub fn get_circuits_handler( &mut self, hard_fork_name: &String, + prover_type: ProverType, ) -> Result>> { match &self.current_fork_name { Some(fork_name) if fork_name == hard_fork_name => { @@ -129,9 +126,10 @@ impl<'a> CircuitsHandlerProvider<'a> { ); if let Some(builder) = self.circuits_handler_builder_map.get(hard_fork_name) { log::info!("building circuits handler for {hard_fork_name}"); - let handler = builder(self.prover_type, self.config, self.geth_client.clone()) + let handler = builder(prover_type, self.config, self.geth_client.clone()) .expect("failed to build circuits handler"); self.current_fork_name = Some(hard_fork_name.clone()); + self.current_prover_type = Some(prover_type); let rc_handler = Rc::new(handler); self.current_circuit = Some(rc_handler.clone()); Ok(rc_handler) @@ -144,31 +142,37 @@ impl<'a> CircuitsHandlerProvider<'a> { pub fn init_vks( &self, - prover_type: ProverType, + prover_types: 
Vec, config: &'a Config, geth_client: Option>>, ) -> Vec { self.circuits_handler_builder_map .iter() .flat_map(|(hard_fork_name, build)| { - let handler = build(prover_type, config, geth_client.clone()) - .expect("failed to build circuits handler"); - - get_task_types(prover_type) - .into_iter() - .map(|task_type| { - let vk = handler - .get_vk(task_type) - .map_or("".to_string(), utils::encode_vk); - log::info!( - "vk for {hard_fork_name}, is {vk}, task_type: {:?}", - task_type - ); - vk + let geth_client_clone = geth_client.clone(); + prover_types + .iter() + .flat_map(move |prover_type| { + let handler = build(*prover_type, config, geth_client_clone.clone()) + .expect("failed to build circuits handler"); + + get_task_types(*prover_type) + .into_iter() + .map(move |task_type| { + let vk = handler + .get_vk(task_type) + .map_or("".to_string(), utils::encode_vk); + log::info!( + "vk for {hard_fork_name}, is {vk}, task_type: {:?}", + task_type + ); + vk + }) + .filter(|vk| !vk.is_empty()) + .collect::>() }) - .filter(|vk| !vk.is_empty()) .collect::>() }) - .collect::>() + .collect() } }