Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions prover/protocol/distributed/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ func ReplaceExternalCoins(initialComp, moduleComp *wizard.CompiledIOP, expr *sym
}
}

// GetFreshModuleComp creates a [wizard.CompiledIOP] object including only the columns relevant to the module.
// It also contains the prover steps for assigning the module column
// GetFreshModuleComp returns a [wizard.DefineFunc] that creates
// a [wizard.CompiledIOP] object including only the columns
// relevant to the module. It also contains the prover steps
// for assigning the module column
func GetFreshModuleComp(
initialComp *wizard.CompiledIOP,
disc ModuleDiscoverer,
initialProver wizard.ProverStep,
moduleName ModuleName,
) *wizard.CompiledIOP {

var (
// initialize the moduleComp
moduleComp = wizard.NewCompiledIOP()
initialRunTime = wizard.RunProver(initialComp, initialProver)
moduleComp = wizard.NewCompiledIOP()
)

for round := 0; round < initialComp.NumRounds(); round++ {
Expand All @@ -65,14 +65,12 @@ func GetFreshModuleComp(

// create a new moduleProver
moduleProver := moduleProver{
cols: columnsInRound,
initRun: initialRunTime,
round: round,
cols: columnsInRound,
round: round,
}

// register Prover action for the module to assign columns per round
moduleComp.RegisterProverAction(round, moduleProver)

}

return moduleComp
Expand All @@ -83,17 +81,19 @@ type moduleProver struct {
round int
// columns for a specific round
cols []ifaces.Column
// runtime of the initial Prover that is parent to the module.
initRun *wizard.ProverRuntime
}

// It implements [wizard.ProverAction] for the module prover.
func (p moduleProver) Run(run *wizard.ProverRuntime) {

if run.ParentRuntime == nil {
utils.Panic("invalid call: the runtime does not have a [ParentRuntime]")
}

for _, col := range p.cols {
// get the witness from the initialProver
colWitness := p.initRun.GetColumn(col.GetColID())
colWitness := run.ParentRuntime.GetColumn(col.GetColID())
// assign it in the module in the round col was declared
run.AssignColumn(col.GetColID(), colWitness, col.Round())
}

}
28 changes: 21 additions & 7 deletions prover/protocol/distributed/compiler/inclusion/inclusion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"testing"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy"
logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum"
"github.com/consensys/linea-monorepo/prover/protocol/distributed"
"github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/inclusion"
md "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer"
md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
Expand Down Expand Up @@ -50,22 +51,35 @@ func TestDistributedLogDerivSum(t *testing.T) {
disc.Analyze(initialComp)

// distribute the columns among modules; this includes also multiplicity columns
moduleComp0 := distributed.GetFreshModuleComp(initialComp, disc, prover, "module0")
moduleComp1 := distributed.GetFreshModuleComp(initialComp, disc, prover, "module1")
moduleComp0 := distributed.GetFreshModuleComp(initialComp, disc, "module0")
moduleComp1 := distributed.GetFreshModuleComp(initialComp, disc, "module1")

// distribute the query LogDerivativeSum among modules.
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp0, "module0", disc)
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp1, "module1", disc)

// Compile and prove for module0
// This compiles the log-derivative queries into global/local queries.
logderiv.CompileLogDerivSum(moduleComp0)
proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) {})
logderiv.CompileLogDerivSum(moduleComp1)

// This adds a dummy compilation step to control that all passes
dummy.CompileAtProverLvl(moduleComp0)
dummy.CompileAtProverLvl(moduleComp1)

// run the initial runtime
initialRuntime := wizard.RunProver(initialComp, prover)

// Compile and prove for module0
proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
valid := wizard.Verify(moduleComp0, proof0)
require.NoError(t, valid)

// Compile and prove for module1
logderiv.CompileLogDerivSum(moduleComp1)
proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) {})
proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
valid1 := wizard.Verify(moduleComp1, proof1)
require.NoError(t, valid1)
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func NewPermutationIntoGrandProductCtx(
for round := 0; round < numRounds; round++ {
queries := initialComp.QueriesNoParams.AllKeysAt(round)
for queryInRound, qName := range queries {

// Skip if it was already compiled
if initialComp.QueriesNoParams.IsIgnored(qName) {
continue
Expand All @@ -99,6 +100,7 @@ func NewPermutationIntoGrandProductCtx(
}
}
}

// We register the grand product query in round one because
// alphas, betas, and the query param are assigned in round one
p.Query = moduleComp.InsertGrandProduct(p.LastRoundPerm+1, qId, p.GdProdInputs)
Expand Down
127 changes: 53 additions & 74 deletions prover/protocol/distributed/compiler/permutation/permutation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"testing"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy"
"github.com/consensys/linea-monorepo/prover/protocol/distributed"
dist_permutation "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/permutation"
modulediscoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer"
"github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/stretchr/testify/require"
Expand All @@ -15,9 +15,7 @@ import (
func TestPermutation(t *testing.T) {

var (
moduleAName = "MODULE_A"
// Initialise the period separating module discoverer
disc = modulediscoverer.PeriodSeperatingModuleDiscoverer{}
moduleAName = "moduleA"
)

testcases := []struct {
Expand All @@ -28,13 +26,13 @@ func TestPermutation(t *testing.T) {
Name: "single-column-no-fragment",
DefineFunc: func(builder *wizard.Builder) {
a := []ifaces.Column{
builder.RegisterCommit("MODULE_A.A0", 4),
builder.RegisterCommit("moduleA.A0", 4),
}
b := []ifaces.Column{
builder.RegisterCommit("MODULE_B.B0", 4),
builder.RegisterCommit("moduleB.B0", 4),
}
c := []ifaces.Column{
builder.RegisterCommit("MODULE_C.C0", 4),
builder.RegisterCommit("moduleC.C0", 4),
}
_ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_A_MOD_B", a, b)
_ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_C_MOD_A", c, a)
Expand All @@ -45,19 +43,19 @@ func TestPermutation(t *testing.T) {
Name: "multi-column-no-fragment",
DefineFunc: func(builder *wizard.Builder) {
a := []ifaces.Column{
builder.RegisterCommit("MODULE_A.A0", 4),
builder.RegisterCommit("MODULE_A.A1", 4),
builder.RegisterCommit("MODULE_A.A2", 4),
builder.RegisterCommit("moduleA.A0", 4),
builder.RegisterCommit("moduleA.A1", 4),
builder.RegisterCommit("moduleA.A2", 4),
}
b := []ifaces.Column{
builder.RegisterCommit("MODULE_B.B0", 4),
builder.RegisterCommit("MODULE_B.B1", 4),
builder.RegisterCommit("MODULE_B.B2", 4),
builder.RegisterCommit("moduleB.B0", 4),
builder.RegisterCommit("moduleB.B1", 4),
builder.RegisterCommit("moduleB.B2", 4),
}
c := []ifaces.Column{
builder.RegisterCommit("MODULE_C.C0", 4),
builder.RegisterCommit("MODULE_C.C1", 4),
builder.RegisterCommit("MODULE_C.C2", 4),
builder.RegisterCommit("moduleC.C0", 4),
builder.RegisterCommit("moduleC.C1", 4),
builder.RegisterCommit("moduleC.C2", 4),
}
_ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_A_MOD_B", a, b)
_ = builder.CompiledIOP.InsertPermutation(0, "P_MOD_C_MOD_A", c, a)
Expand All @@ -69,38 +67,38 @@ func TestPermutation(t *testing.T) {
DefineFunc: func(builder *wizard.Builder) {
a := [][]ifaces.Column{
{
builder.RegisterCommit("MODULE_A.A00", 4),
builder.RegisterCommit("MODULE_A.A10", 4),
builder.RegisterCommit("MODULE_A.A20", 4),
builder.RegisterCommit("moduleA.A00", 4),
builder.RegisterCommit("moduleA.A10", 4),
builder.RegisterCommit("moduleA.A20", 4),
},
{
builder.RegisterCommit("MODULE_A.A01", 4),
builder.RegisterCommit("MODULE_A.A11", 4),
builder.RegisterCommit("MODULE_A.A21", 4),
builder.RegisterCommit("moduleA.A01", 4),
builder.RegisterCommit("moduleA.A11", 4),
builder.RegisterCommit("moduleA.A21", 4),
},
}
b := [][]ifaces.Column{
{
builder.RegisterCommit("MODULE_B.B00", 4),
builder.RegisterCommit("MODULE_B.B10", 4),
builder.RegisterCommit("MODULE_B.B20", 4),
builder.RegisterCommit("moduleB.B00", 4),
builder.RegisterCommit("moduleB.B10", 4),
builder.RegisterCommit("moduleB.B20", 4),
},
{
builder.RegisterCommit("MODULE_B.B01", 4),
builder.RegisterCommit("MODULE_B.B11", 4),
builder.RegisterCommit("MODULE_B.B21", 4),
builder.RegisterCommit("moduleB.B01", 4),
builder.RegisterCommit("moduleB.B11", 4),
builder.RegisterCommit("moduleB.B21", 4),
},
}
c := [][]ifaces.Column{
{
builder.RegisterCommit("MODULE_C.C00", 4),
builder.RegisterCommit("MODULE_C.C10", 4),
builder.RegisterCommit("MODULE_C.C20", 4),
builder.RegisterCommit("moduleC.C00", 4),
builder.RegisterCommit("moduleC.C10", 4),
builder.RegisterCommit("moduleC.C20", 4),
},
{
builder.RegisterCommit("MODULE_C.C01", 4),
builder.RegisterCommit("MODULE_C.C11", 4),
builder.RegisterCommit("MODULE_C.C21", 4),
builder.RegisterCommit("moduleC.C01", 4),
builder.RegisterCommit("moduleC.C11", 4),
builder.RegisterCommit("moduleC.C21", 4),
},
}
_ = builder.CompiledIOP.InsertFragmentedPermutation(0, "P_MOD_A_MOD_B", a, b)
Expand All @@ -114,55 +112,36 @@ func TestPermutation(t *testing.T) {

t.Run(tc.Name, func(t *testing.T) {

initialComp := wizard.Compile(tc.DefineFunc)

disc.Analyze(initialComp)

moduleAComp := wizard.Compile(func(build *wizard.Builder) {

for _, colName := range initialComp.Columns.AllKeys() {

col := initialComp.Columns.GetHandle(colName)
if !disc.ColumnIsInModule(col, moduleAName) {
continue
}

build.RegisterCommit(col.GetColID(), col.Size())
}
}, dummy.CompileAtProverLvl)

var (
_ = dist_permutation.NewPermutationIntoGrandProductCtx(
dist_permutation.Settings{TargetModuleName: moduleAName},
initialComp, moduleAComp, &disc,
)
initialRun *wizard.ProverRuntime
)

// This function assigns the initial module and is aimed at working
// for all test-case.
initialProve := func(run *wizard.ProverRuntime) {
for _, colName := range run.Spec.Columns.AllKeys() {
run.AssignColumn(colName, smartvectors.ForTest(1, 2, 3, 4))
}

initialRun = run
}

_ = wizard.Prove(initialComp, initialProve)
// initialComp is defined according to the define function provided by the
// test-case.
initialComp := wizard.Compile(tc.DefineFunc)

disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{}
disc.Analyze(initialComp)

moduleAProve := func(run *wizard.ProverRuntime) {
for _, colName := range initialComp.Columns.AllKeys() {
// This declares a compiled IOP with only the columns of the module A
moduleAComp := distributed.GetFreshModuleComp(initialComp, &disc, moduleAName)

col := initialComp.Columns.GetHandle(colName)
if !disc.ColumnIsInModule(col, moduleAName) {
continue
}
// This distributes the permutation queries
dist_permutation.NewPermutationIntoGrandProductCtx(
dist_permutation.Settings{TargetModuleName: moduleAName},
initialComp, moduleAComp, &disc,
)

c := initialRun.GetColumn(colName)
run.AssignColumn(colName, c)
}
}
// This runs the initial prover
initialRuntime := wizard.RunProver(initialComp, initialProve)

proof := wizard.Prove(moduleAComp, moduleAProve)
proof := wizard.Prove(moduleAComp, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
valid := wizard.Verify(moduleAComp, proof)
require.NoError(t, valid)

Expand Down
4 changes: 2 additions & 2 deletions prover/protocol/distributed/compiler/permutation/settings.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package dist_permutation

import modulediscoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/module_discoverer"
import "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer"

type Settings struct {
// Name of the target module
TargetModuleName modulediscoverer.ModuleName
TargetModuleName namebaseddiscoverer.ModuleName
}
4 changes: 2 additions & 2 deletions prover/protocol/distributed/distributed.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type ModuleDiscoverer interface {
// group best the columns into modules.
Analyze(comp *wizard.CompiledIOP)
NbModules() int
ModuleList(comp *wizard.CompiledIOP) []ModuleName
ModuleList() []ModuleName
FindModule(col ifaces.Column) ModuleName
// given a query and a module name it checks if the query is inside the module
ExpressionIsInModule(*symbolic.Expression, ModuleName) bool
Expand All @@ -50,7 +50,7 @@ func Distribute(initialWizard *wizard.CompiledIOP, disc ModuleDiscoverer, maxSeg
// analyze the initialWizard to split it to modules.
disc.Analyze(initialWizard)

moduleLs := disc.ModuleList(initialWizard)
moduleLs := disc.ModuleList()
distModules := []DistributedModule{}

for _, modName := range moduleLs {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package distributed
package namebaseddiscoverer

import (
"strings"
Expand Down Expand Up @@ -60,7 +60,7 @@ func (p *PeriodSeperatingModuleDiscoverer) NbModules() int {
}

// ModuleList returns the list of module names
func (p *PeriodSeperatingModuleDiscoverer) ModuleList(comp *wizard.CompiledIOP) []ModuleName {
func (p *PeriodSeperatingModuleDiscoverer) ModuleList() []ModuleName {
moduleNames := make([]ModuleName, 0, len(p.modules))
for moduleName := range p.modules {
moduleNames = append(moduleNames, moduleName)
Expand Down
Loading