Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion prover/protocol/distributed/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ import (
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/consensys/linea-monorepo/prover/utils/collection"
)

// ReplaceExternalCoins replaces the external coins with local coins, for a given expression.
// It does not check if all the columns from the expression are in the module.
// If this is required should be check before calling ReplaceExternalCoins.
// If the Coin does not exist in the initialComp it panics.
func ReplaceExternalCoins(initialComp, moduleComp *wizard.CompiledIOP, expr *symbolic.Expression) {
// It adds the coins to the translationMap.
func ReplaceExternalCoins(
initialComp, moduleComp *wizard.CompiledIOP,
expr *symbolic.Expression,
translationMap collection.Mapping[string, *symbolic.Expression],
) {
var (
board = expr.Board()
metadata = board.ListVariableMetadata()
Expand All @@ -29,6 +35,7 @@ func ReplaceExternalCoins(initialComp, moduleComp *wizard.CompiledIOP, expr *sym
}
if !moduleComp.Coins.Exists(v.Name) {
moduleComp.InsertCoin(v.Round, v.Name, coin.Field)
translationMap.InsertNew(v.String(), symbolic.NewVariable(v))
}
}
}
Expand Down
92 changes: 92 additions & 0 deletions prover/protocol/distributed/comp_splitting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package distributed

import (
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/utils"
)

// SegmentModuleInputs stores the inputs for both
// vertical and horizontal splitting of a [wizard.CompiledIOP] object.
type SegmentModuleInputs struct {
// InitialComp subject to the splitting
InitialComp *wizard.CompiledIOP
// inputs for horizontal splitting
Disc ModuleDiscoverer
ModuleName ModuleName
// inputs for vertical splitting
NumSegmentsInModule int
}

// GetFreshSegmentModuleComp returns a [wizard.DefineFunc] that creates
// a [wizard.CompiledIOP] object including only the columns relevant to the module.
// It splits the columns to the segments and assign them to the relevant CompiledIOP.
// It also contains the prover steps for assigning the module column.
// For all the segments from the same module, compiledIOP object is the same.
func GetFreshSegmentModuleComp(in SegmentModuleInputs) *wizard.CompiledIOP {

var (
// initialize the moduleComp
segModComp = wizard.NewCompiledIOP()
initialComp = in.InitialComp
)

for round := 0; round < initialComp.NumRounds(); round++ {
var columnsInRound []ifaces.Column
// get the columns per round
for _, colName := range initialComp.Columns.AllKeysAt(round) {

col := initialComp.Columns.GetHandle(colName)
if !in.Disc.ColumnIsInModule(col, in.ModuleName) {
continue
}

segModComp.InsertCommit(col.Round(), col.GetColID(), col.Size()/in.NumSegmentsInModule)
columnsInRound = append(columnsInRound, col)
}

// create a new moduleProver
segModuleProver := segmentModuleProver{
cols: columnsInRound,
round: round,
numSegments: in.NumSegmentsInModule,
}

// register Prover action for the segment-module to assign columns per round
segModComp.RegisterProverAction(round, segModuleProver)
}

return segModComp
}

// it stores the input for the module prover
type segmentModuleProver struct {
round int
// columns for a specific round
cols []ifaces.Column
numSegments int
}

// It implements [wizard.ProverAction] for the module prover.
func (p segmentModuleProver) Run(run *wizard.ProverRuntime) {

if run.ParentRuntime == nil {
utils.Panic("invalid call: the runtime does not have a [ParentRuntime]")
}
if run.ProverID > p.numSegments {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should make ProverID a field of segmentModuleProver and not ProverRuntime

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ModulCop for all segments should remain the same and proverID should only affect the prover.

On the other hand the run time already has the parent run time and is natural to say what is the position of the prover among its siblings.

panic("proverID can not be larger than number of segments")
}

for _, col := range p.cols {
// get the witness from the initialProver
colWitness := run.ParentRuntime.GetColumn(col.GetColID())
colSegWitness := getSegmentFromWitness(colWitness, p.numSegments, run.ProverID)
// assign it in the module in the round col was declared
run.AssignColumn(col.GetColID(), colSegWitness, col.Round())
}
}

func getSegmentFromWitness(wit ifaces.ColAssignment, numSegs, segID int) ifaces.ColAssignment {
segSize := wit.Len() / numSegs
return wit.SubVector(segSize*segID, segSize*segID+segSize)
}
59 changes: 48 additions & 11 deletions prover/protocol/distributed/compiler/inclusion/inclusion.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/consensys/linea-monorepo/prover/utils/collection"
)

const (
Expand Down Expand Up @@ -70,20 +71,24 @@ func DistributeLogDerivativeSum(
func GetShareOfLogDerivativeSum(in DistributionInputs) {

var (
initialComp = in.InitialComp
moduleComp = in.ModuleComp
numerator []*symbolic.Expression
denominator []*symbolic.Expression
keyIsInModule bool
zCatalog = make(map[int]*query.LogDerivativeSumInput)
logDeriv = in.Query
round = logDeriv.Round
// create a translation map from the columns of moduleComp.
translationMap = createTranslationMap(moduleComp)
)

// extract the share of the module from the global sum.
for size := range logDeriv.Inputs {

for i := range logDeriv.Inputs[size].Numerator {
var (
numerator []*symbolic.Expression
denominator []*symbolic.Expression
)

for i, num := range logDeriv.Inputs[size].Numerator {

// if Denominator is in the module pass the numerator from initialComp to moduleComp
// Particularly, T might be in the module and needs to take M from initialComp.
Expand All @@ -92,22 +97,36 @@ func GetShareOfLogDerivativeSum(in DistributionInputs) {
if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Numerator[i], in.ModuleName) {
utils.Panic("Denominator is in the module but not Numerator")
}
// update translationMap by adding local coins
// the previous check guarantees that all the columns
// from the expression are in the module
// Thus we can add the coins locally (i.e., without [distributed.ModuleDiscoverer]).
distributed.ReplaceExternalCoins(in.InitialComp, moduleComp, logDeriv.Inputs[size].Denominator[i], translationMap)

denominator = append(denominator, logDeriv.Inputs[size].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[size].Numerator[i])
denominator = append(denominator,
// get the corresponding expression from the module
// this is mainly for adjusting the size of expressions
// in the module-segments.
logDeriv.Inputs[size].Denominator[i].Replay(translationMap),
)

numerator = append(numerator, num.Replay(translationMap))

// replaces the external coins with local coins
// note that they just appear in the denominator.
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[size].Denominator[i])
keyIsInModule = true
}
}

// if there in any expression relevant to the current key, add them to zCatalog
if keyIsInModule {

board := denominator[0].Board()
// size of the expressions in the module
sizeInModule := column.ExprIsOnSameLengthHandles(&board)

// zCatalog specific to the module
zCatalog[size] = &query.LogDerivativeSumInput{
Size: size,
// due to vertical splitting size in module-segments may be different from size in the initialComp.
zCatalog[sizeInModule] = &query.LogDerivativeSumInput{
Size: sizeInModule,
Numerator: numerator,
Denominator: denominator,
}
Expand Down Expand Up @@ -169,3 +188,21 @@ func getLogDerivativeSumResult(zCatalog map[int]*query.LogDerivativeSumInput, ru
}
return actualSum
}

func createTranslationMap(comp *wizard.CompiledIOP) collection.Mapping[string, *symbolic.Expression] {

var (
exprMap = collection.NewMapping[string, *symbolic.Expression]()
expr *symbolic.Expression
)

for _, colID := range comp.Columns.AllKeys() {
expr = ifaces.ColumnAsVariable(comp.Columns.GetHandle(colID))
exprMap.InsertNew(string(colID), expr)
}
for _, coinID := range comp.Coins.AllKeys() {
expr = symbolic.NewVariable(comp.Coins.Data(coinID))
exprMap.InsertNew(string(coinID), expr)
}
return exprMap
}
61 changes: 41 additions & 20 deletions prover/protocol/distributed/compiler/inclusion/inclusion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import (

// It tests DistributedLogDerivSum.
func TestDistributedLogDerivSum(t *testing.T) {
const (
numSegModule0 = 2
numSegModule1 = 2
)

//initialComp
define := func(b *wizard.Builder) {
Expand All @@ -43,43 +47,60 @@ func TestDistributedLogDerivSum(t *testing.T) {
}

// in initialComp replace inclusion queries with a global LogDerivativeSum
// it also create new columns relevant to the preparation such as multiplicity columns.
// it also creates new columns relevant to the preparation such as multiplicity columns.
initialComp := wizard.Compile(define, distributed.IntoLogDerivativeSum)

// Initialize the period separating module discoverer
disc := &md.PeriodSeperatingModuleDiscoverer{}
disc.Analyze(initialComp)

// distribute the columns among modules; this includes also multiplicity columns
moduleComp0 := distributed.GetFreshModuleComp(initialComp, disc, "module0")
moduleComp1 := distributed.GetFreshModuleComp(initialComp, disc, "module1")
// distribute the columns among modules and segments; this includes also multiplicity columns
// for all the segments from the same module, compiledIOP object is the same.
moduleComp0 := distributed.GetFreshSegmentModuleComp(
distributed.SegmentModuleInputs{
InitialComp: initialComp,
Disc: disc,
ModuleName: "module0",
NumSegmentsInModule: numSegModule0,
},
)
moduleComp1 := distributed.GetFreshSegmentModuleComp(distributed.SegmentModuleInputs{
InitialComp: initialComp,
Disc: disc,
ModuleName: "module1",
NumSegmentsInModule: numSegModule1,
})

// distribute the query LogDerivativeSum among modules.
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp0, "module0", disc)
inclusion.DistributeLogDerivativeSum(initialComp, moduleComp1, "module1", disc)

// This compiles the log-derivative queries into global/local queries.
logderiv.CompileLogDerivSum(moduleComp0)
logderiv.CompileLogDerivSum(moduleComp1)

// This adds a dummy compilation step to control that all passes
dummy.CompileAtProverLvl(moduleComp0)
dummy.CompileAtProverLvl(moduleComp1)
wizard.ContinueCompilation(moduleComp0, logderiv.CompileLogDerivSum, dummy.Compile)
wizard.ContinueCompilation(moduleComp1, logderiv.CompileLogDerivSum, dummy.Compile)

// run the initial runtime
initialRuntime := wizard.RunProver(initialComp, prover)

// Compile and prove for module0
proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
valid := wizard.Verify(moduleComp0, proof0)
require.NoError(t, valid)
for proverID := 0; proverID < numSegModule0; proverID++ {
proof0 := wizard.Prove(moduleComp0, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
// inputs for vertical splitting of the witness
run.ProverID = proverID
})
valid := wizard.Verify(moduleComp0, proof0)
require.NoError(t, valid)
}

// Compile and prove for module1
proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
})
valid1 := wizard.Verify(moduleComp1, proof1)
require.NoError(t, valid1)
for proverID := 0; proverID < numSegModule1; proverID++ {
proof1 := wizard.Prove(moduleComp1, func(run *wizard.ProverRuntime) {
run.ParentRuntime = initialRuntime
// inputs for vertical splitting of the witness
run.ProverID = proverID
})
valid1 := wizard.Verify(moduleComp1, proof1)
require.NoError(t, valid1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func periodLogicToDetermineModule(col ifaces.Column) ModuleName {
colName := col.GetColID()
// for multiplicity Column it is "TABLE_moduleName." So we should separate the ModuleName from this.
name := ModuleName(periodSeparator(string(colName)))
index := strings.Index(name, "_")
index := strings.LastIndex(name, "_")
if index != -1 {
name = name[index+1:]
}
Expand Down
17 changes: 11 additions & 6 deletions prover/protocol/wizard/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,27 @@ Compile an IOP from a protocol definition
func Compile(define DefineFunc, compilers ...func(*CompiledIOP)) *CompiledIOP {
builder := newBuilder()
define(&builder)
comp := builder.CompiledIOP
return ContinueCompilation(comp, compilers...)
}

// ContinueCompilation continues a set of compilation steps over a initial CompiledIOP object.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good idea to add that IMO

func ContinueCompilation(rootComp *CompiledIOP, compilers ...func(*CompiledIOP)) *CompiledIOP {
/*
For sanity, we need to ensure the protocol is well formed. All
registers should have the same number of rounds. The simplest to
iron this out after the define function. We still make sure than
no more rounds are allocated anywhere.
*/
comp := builder.CompiledIOP
comp := rootComp
numRounds := comp.NumRounds()

builder.equalizeRounds(numRounds)
comp.equalizeRounds(numRounds)

for _, compiler := range compilers {
compiler(comp)
numRounds := comp.NumRounds()
builder.equalizeRounds(numRounds)
comp.equalizeRounds(numRounds)
}

if comp.SubProvers.Len() < comp.NumRounds() {
Expand All @@ -75,7 +81,7 @@ func Compile(define DefineFunc, compilers ...func(*CompiledIOP)) *CompiledIOP {
)
}

return builder.CompiledIOP
return comp
}

// NewCompiledIOP initializes a CompiledIOP object.
Expand Down Expand Up @@ -258,8 +264,7 @@ func (b *Builder) LocalOpening(name ifaces.QueryID, pol ifaces.Column) query.Loc
Equalizes the length of all the structure so that they all have the same
numbers of rounds
*/
func (b *Builder) equalizeRounds(numRounds int) {
comp := b.CompiledIOP
func (comp *CompiledIOP) equalizeRounds(numRounds int) {

helpMsg := "If you are seeing this message it's probably because you insert queries one round too late."

Expand Down
5 changes: 5 additions & 0 deletions prover/protocol/wizard/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ type ProverRuntime struct {
// in the distributed prover by the module runtimes to access the initial
// wizard runtime.
ParentRuntime *ProverRuntime

// ProverID indicates the id of the prover among its siblings.
// It is merely in the context of the distributed prover;
// for vertical splitting to extract the relevant segment of a witness.
ProverID int

// FiatShamirHistory tracks the fiat-shamir state at the beginning of every
// round. The first entry is the initial state, the final entry is the final
Expand Down