diff --git a/prover/protocol/distributed/common.go b/prover/protocol/distributed/common.go index 0c458fd7bbe..9dfbdb19881 100644 --- a/prover/protocol/distributed/common.go +++ b/prover/protocol/distributed/common.go @@ -34,19 +34,19 @@ func ReplaceExternalCoins(initialComp, moduleComp *wizard.CompiledIOP, expr *sym } } -// GetFreshModuleComp creates a [wizard.CompiledIOP] object including only the columns relevant to the module. -// It also contains the prover steps for assigning the module column +// GetFreshModuleComp returns a [wizard.DefineFunc] that creates +// a [wizard.CompiledIOP] object including only the columns +// relevant to the module. It also contains the prover steps +// for assigning the module column func GetFreshModuleComp( initialComp *wizard.CompiledIOP, disc ModuleDiscoverer, - initialProver wizard.ProverStep, moduleName ModuleName, ) *wizard.CompiledIOP { var ( // initialize the moduleComp - moduleComp = wizard.NewCompiledIOP() - initialRunTime = wizard.RunProver(initialComp, initialProver) + moduleComp = wizard.NewCompiledIOP() ) for round := 0; round < initialComp.NumRounds(); round++ { @@ -65,14 +65,12 @@ func GetFreshModuleComp( // create a new moduleProver moduleProver := moduleProver{ - cols: columnsInRound, - initRun: initialRunTime, - round: round, + cols: columnsInRound, + round: round, } // register Prover action for the module to assign columns per round moduleComp.RegisterProverAction(round, moduleProver) - } return moduleComp @@ -83,17 +81,19 @@ type moduleProver struct { round int // columns for a specific round cols []ifaces.Column - // runtime of the initial Prover that is parent to the module. - initRun *wizard.ProverRuntime } // It implements [wizard.ProverAction] for the module prover. func (p moduleProver) Run(run *wizard.ProverRuntime) { + + if run.ParentRuntime == nil { + utils.Panic("invalid call: the runtime does not have a [ParentRuntime]") + } + for _, col := range p.cols { // get the witness from the initialProver - colWitness := p.initRun.GetColumn(col.GetColID()) + colWitness := run.ParentRuntime.GetColumn(col.GetColID()) // assign it in the module in the round col was declared run.AssignColumn(col.GetColID(), colWitness, col.Round()) } - } diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go index f2227d51323..7d19a963a5b 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum" "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/inclusion" - md "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer" + md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/symbolic" @@ -50,22 +51,35 @@ func TestDistributedLogDerivSum(t *testing.T) { disc.Analyze(initialComp) // distribute the columns among modules; this includes also multiplicity columns - moduleComp0 := distributed.GetFreshModuleComp(initialComp, disc, prover, "module0") - moduleComp1 := distributed.GetFreshModuleComp(initialComp, disc, prover, "module1") + moduleComp0 := distributed.GetFreshModuleComp(initialComp, disc, "module0") + moduleComp1 := distributed.GetFreshModuleComp(initialComp, disc, "module1") // distribute the query LogDerivativeSum among modules. inclusion.DistributeLogDerivativeSum(initialComp, moduleComp0, "module0", disc) inclusion.DistributeLogDerivativeSum(initialComp, moduleComp1, "module1", disc) - // Compile and prove for module0 + // This compiles the log-derivative queries into global/local queries. logderiv.CompileLogDerivSum(moduleComp0) - proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) {}) + logderiv.CompileLogDerivSum(moduleComp1) + + // This adds a dummy compilation step to control that all passes + dummy.CompileAtProverLvl(moduleComp0) + dummy.CompileAtProverLvl(moduleComp1) + + // run the initial runtime + initialRuntime := wizard.RunProver(initialComp, prover) + + // Compile and prove for module0 + proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + }) valid := wizard.Verify(moduleComp0, proof0) require.NoError(t, valid) // Compile and prove for module1 - logderiv.CompileLogDerivSum(moduleComp1) - proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) {}) + proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + }) valid1 := wizard.Verify(moduleComp1, proof1) require.NoError(t, valid1) } diff --git a/prover/protocol/distributed/compiler/permutation/permutation.go b/prover/protocol/distributed/compiler/permutation/permutation.go index cceddcb828c..18ee8effca9 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation.go +++ b/prover/protocol/distributed/compiler/permutation/permutation.go @@ -76,6 +76,7 @@ func NewPermutationIntoGrandProductCtx( for round := 0; round < numRounds; round++ { queries := initialComp.QueriesNoParams.AllKeysAt(round) for queryInRound, qName := range queries { + // Skip if it was already compiled if initialComp.QueriesNoParams.IsIgnored(qName) { continue @@ -99,6 +100,7 @@ func NewPermutationIntoGrandProductCtx( } } } + // We register the grand product query in round one because // alphas, betas, and the query param are assigned in round one p.Query = moduleComp.InsertGrandProduct(p.LastRoundPerm+1, qId, p.GdProdInputs) diff --git a/prover/protocol/distributed/compiler/permutation/permutation_test.go b/prover/protocol/distributed/compiler/permutation/permutation_test.go index 1fe81863ede..75d966b3e23 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation_test.go +++ b/prover/protocol/distributed/compiler/permutation/permutation_test.go @@ -4,9 +4,9 @@ import ( "testing" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" - "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" dist_permutation "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/permutation" - modulediscoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/stretchr/testify/require" @@ -15,9 +15,7 @@ import ( func TestPermutation(t *testing.T) { var ( - moduleAName = "MODULE_A" - // Initialise the period separating module discoverer - disc = modulediscoverer.PeriodSeperatingModuleDiscoverer{} + moduleAName = "moduleA" ) testcases := []struct { @@ -28,13 +26,13 @@ func TestPermutation(t *testing.T) { Name: "single-column-no-fragment", DefineFunc: func(builder *wizard.Builder) { a := []ifaces.Column{ - builder.RegisterCommit("MODULE_A.A0", 4), + builder.RegisterCommit("moduleA.A0", 4), } b := []ifaces.Column{ - builder.RegisterCommit("MODULE_B.B0", 4), + builder.RegisterCommit("moduleB.B0", 4), } c := []ifaces.Column{ - builder.RegisterCommit("MODULE_C.C0", 4), + builder.RegisterCommit("moduleC.C0", 4), } _ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_A_MOD_B", a, b) _ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_C_MOD_A", c, a) @@ -45,19 +43,19 @@ func TestPermutation(t *testing.T) { Name: "multi-column-no-fragment", DefineFunc: func(builder *wizard.Builder) { a := []ifaces.Column{ - builder.RegisterCommit("MODULE_A.A0", 4), - builder.RegisterCommit("MODULE_A.A1", 4), - builder.RegisterCommit("MODULE_A.A2", 4), + builder.RegisterCommit("moduleA.A0", 4), + builder.RegisterCommit("moduleA.A1", 4), + builder.RegisterCommit("moduleA.A2", 4), } b := []ifaces.Column{ - builder.RegisterCommit("MODULE_B.B0", 4), - builder.RegisterCommit("MODULE_B.B1", 4), - builder.RegisterCommit("MODULE_B.B2", 4), + builder.RegisterCommit("moduleB.B0", 4), + builder.RegisterCommit("moduleB.B1", 4), + builder.RegisterCommit("moduleB.B2", 4), } c := []ifaces.Column{ - builder.RegisterCommit("MODULE_C.C0", 4), - builder.RegisterCommit("MODULE_C.C1", 4), - builder.RegisterCommit("MODULE_C.C2", 4), + builder.RegisterCommit("moduleC.C0", 4), + builder.RegisterCommit("moduleC.C1", 4), + builder.RegisterCommit("moduleC.C2", 4), } _ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_A_MOD_B", a, b) _ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_C_MOD_A", c, a) @@ -69,38 +67,38 @@ func TestPermutation(t *testing.T) { DefineFunc: func(builder *wizard.Builder) { a := [][]ifaces.Column{ { - builder.RegisterCommit("MODULE_A.A00", 4), - builder.RegisterCommit("MODULE_A.A10", 4), - builder.RegisterCommit("MODULE_A.A20", 4), + builder.RegisterCommit("moduleA.A00", 4), + builder.RegisterCommit("moduleA.A10", 4), + builder.RegisterCommit("moduleA.A20", 4), }, { - builder.RegisterCommit("MODULE_A.A01", 4), - builder.RegisterCommit("MODULE_A.A11", 4), - builder.RegisterCommit("MODULE_A.A21", 4), + builder.RegisterCommit("moduleA.A01", 4), + builder.RegisterCommit("moduleA.A11", 4), + builder.RegisterCommit("moduleA.A21", 4), }, } b := [][]ifaces.Column{ { - builder.RegisterCommit("MODULE_B.B00", 4), - builder.RegisterCommit("MODULE_B.B10", 4), - builder.RegisterCommit("MODULE_B.B20", 4), + builder.RegisterCommit("moduleB.B00", 4), + builder.RegisterCommit("moduleB.B10", 4), + builder.RegisterCommit("moduleB.B20", 4), }, { - builder.RegisterCommit("MODULE_B.B01", 4), - builder.RegisterCommit("MODULE_B.B11", 4), - builder.RegisterCommit("MODULE_B.B21", 4), + builder.RegisterCommit("moduleB.B01", 4), + builder.RegisterCommit("moduleB.B11", 4), + builder.RegisterCommit("moduleB.B21", 4), }, } c := [][]ifaces.Column{ { - builder.RegisterCommit("MODULE_C.C00", 4), - builder.RegisterCommit("MODULE_C.C10", 4), - builder.RegisterCommit("MODULE_C.C20", 4), + builder.RegisterCommit("moduleC.C00", 4), + builder.RegisterCommit("moduleC.C10", 4), + builder.RegisterCommit("moduleC.C20", 4), }, { - builder.RegisterCommit("MODULE_C.C01", 4), - builder.RegisterCommit("MODULE_C.C11", 4), - builder.RegisterCommit("MODULE_C.C21", 4), + builder.RegisterCommit("moduleC.C01", 4), + builder.RegisterCommit("moduleC.C11", 4), + builder.RegisterCommit("moduleC.C21", 4), }, } _ = builder.CompiledIOP.InsertFragmentedPermutation(0, "P_MOD_A_MOD_B", a, b) @@ -114,55 +112,36 @@ func TestPermutation(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { - initialComp := wizard.Compile(tc.DefineFunc) - - disc.Analyze(initialComp) - - moduleAComp := wizard.Compile(func(build *wizard.Builder) { - - for _, colName := range initialComp.Columns.AllKeys() { - - col := initialComp.Columns.GetHandle(colName) - if !disc.ColumnIsInModule(col, moduleAName) { - continue - } - - build.RegisterCommit(col.GetColID(), col.Size()) - } - }, dummy.CompileAtProverLvl) - - var ( - _ = dist_permutation.NewPermutationIntoGrandProductCtx( - dist_permutation.Settings{TargetModuleName: moduleAName}, - initialComp, moduleAComp, &disc, - ) - initialRun *wizard.ProverRuntime - ) - + // This function assigns the initial module and is aimed at working + // for all test-case. initialProve := func(run *wizard.ProverRuntime) { for _, colName := range run.Spec.Columns.AllKeys() { run.AssignColumn(colName, smartvectors.ForTest(1, 2, 3, 4)) } - - initialRun = run } - _ = wizard.Prove(initialComp, initialProve) + // initialComp is defined according to the define function provided by the + // test-case. + initialComp := wizard.Compile(tc.DefineFunc) + + disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{} + disc.Analyze(initialComp) - moduleAProve := func(run *wizard.ProverRuntime) { - for _, colName := range initialComp.Columns.AllKeys() { + // This declares a compiled IOP with only the columns of the module A + moduleAComp := distributed.GetFreshModuleComp(initialComp, &disc, moduleAName) - col := initialComp.Columns.GetHandle(colName) - if !disc.ColumnIsInModule(col, moduleAName) { - continue - } + // This distributes the permutation queries + dist_permutation.NewPermutationIntoGrandProductCtx( + dist_permutation.Settings{TargetModuleName: moduleAName}, + initialComp, moduleAComp, &disc, + ) - c := initialRun.GetColumn(colName) - run.AssignColumn(colName, c) - } - } + // This runs the initial prover + initialRuntime := wizard.RunProver(initialComp, initialProve) - proof := wizard.Prove(moduleAComp, moduleAProve) + proof := wizard.Prove(moduleAComp, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + }) valid := wizard.Verify(moduleAComp, proof) require.NoError(t, valid) diff --git a/prover/protocol/distributed/compiler/permutation/settings.go b/prover/protocol/distributed/compiler/permutation/settings.go index fad99e6dbf4..59de8d1f3ba 100644 --- a/prover/protocol/distributed/compiler/permutation/settings.go +++ b/prover/protocol/distributed/compiler/permutation/settings.go @@ -1,8 +1,8 @@ package dist_permutation -import modulediscoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer" +import "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" type Settings struct { // Name of the target module - TargetModuleName modulediscoverer.ModuleName + TargetModuleName namebaseddiscoverer.ModuleName } diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index 5418e4e5924..65798fea9b6 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -31,7 +31,7 @@ type ModuleDiscoverer interface { // group best the columns into modules. Analyze(comp *wizard.CompiledIOP) NbModules() int - ModuleList(comp *wizard.CompiledIOP) []ModuleName + ModuleList() []ModuleName FindModule(col ifaces.Column) ModuleName // given a query and a module name it checks if the query is inside the module ExpressionIsInModule(*symbolic.Expression, ModuleName) bool @@ -50,7 +50,7 @@ func Distribute(initialWizard *wizard.CompiledIOP, disc ModuleDiscoverer, maxSeg // analyze the initialWizard to split it to modules. disc.Analyze(initialWizard) - moduleLs := disc.ModuleList(initialWizard) + moduleLs := disc.ModuleList() distModules := []DistributedModule{} for _, modName := range moduleLs { diff --git a/prover/protocol/distributed/module_discoverer/period_separating_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go similarity index 97% rename from prover/protocol/distributed/module_discoverer/period_separating_module_discoverer.go rename to prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go index 00e3856c534..f94e195f79e 100644 --- a/prover/protocol/distributed/module_discoverer/period_separating_module_discoverer.go +++ b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go @@ -1,4 +1,4 @@ -package distributed +package namebaseddiscoverer import ( "strings" @@ -60,7 +60,7 @@ func (p *PeriodSeperatingModuleDiscoverer) NbModules() int { } // ModuleList returns the list of module names -func (p *PeriodSeperatingModuleDiscoverer) ModuleList(comp *wizard.CompiledIOP) []ModuleName { +func (p *PeriodSeperatingModuleDiscoverer) ModuleList() []ModuleName { moduleNames := make([]ModuleName, 0, len(p.modules)) for moduleName := range p.modules { moduleNames = append(moduleNames, moduleName) diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index c4aaeee2e5d..7f091f4bb19 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -120,6 +120,12 @@ type ProverRuntime struct { // lock is global lock so that the assignment maps are thread safes lock *sync.Mutex + + // ParentRuntime stores an external runtime that can be accessed by the + // prover steps to retrieve data from a parent runtime. This can be used + // in the distributed prover by the module runtimes to access the initial + // wizard runtime. + ParentRuntime *ProverRuntime } // Prove is the top-level function that runs the Prover on the user's side. It @@ -165,7 +171,6 @@ func Prove(c *CompiledIOP, highLevelprover ProverStep) Proof { return Proof{ Messages: messages, QueriesParams: runtime.QueriesParams, - RunTime: runtime, } } diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index 16832479842..2c5b8da5fa2 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -32,9 +32,6 @@ type Proof struct { // QueriesParams stores all the query parameters (i.e) the messages of the // oracle to the verifier. QueriesParams collection.Mapping[ifaces.QueryID, ifaces.QueryParams] - - // RunTime is the run time of the prover during the proof generation - RunTime *ProverRuntime } // VerifierStep specifies a single step of verifier for a single subprotocol.