diff --git a/prover/crypto/vortex/opening_prover.go b/prover/crypto/vortex/opening_prover.go index 03e8c9c24ec..64f48a70a21 100644 --- a/prover/crypto/vortex/opening_prover.go +++ b/prover/crypto/vortex/opening_prover.go @@ -41,7 +41,6 @@ type OpeningProof struct { // functions and is motivated by the fact that this is simpler to construct in // our settings. func (params *Params) InitOpeningWithLC(committedSV []smartvectors.SmartVector, randomCoin field.Element) *OpeningProof { - proof := OpeningProof{} if len(committedSV) == 0 { utils.Panic("attempted to open an empty witness") @@ -55,15 +54,48 @@ func (params *Params) InitOpeningWithLC(committedSV []smartvectors.SmartVector, for i := range committedSV { subTask = append(subTask, committedSV[i].SubVector(start, stop)) } - // Collect the result in the larger slice at the end + // Collect the result in the larger slice at the end subResult := smartvectors.PolyEval(subTask, randomCoin) subResult.WriteInSlice(linComb[start:stop]) }) linCombSV := smartvectors.NewRegular(linComb) - proof.LinearCombination = params.rsEncode(linCombSV, nil) - return &proof + + return &OpeningProof{ + LinearCombination: params.rsEncode(linCombSV, nil), + } +} + +// InitOpeningFromAlreadyEncodedLC initiates the construction of a Vortex proof +// by returning the encoding of the linear combinations of the committed +// row-vectors contained in committedSV by the successive powers of randomCoin. +// +// The returned proof is partially assigned and must be completed using +// [WithEntryList] to conclude the opening protocol. +func (params *Params) InitOpeningFromAlreadyEncodedLC(rsCommittedSV EncodedMatrix, randomCoin field.Element) *OpeningProof { + + if len(rsCommittedSV) == 0 { + utils.Panic("attempted to open an empty witness") + } + + // Compute the linear combination + linComb := make([]field.Element, params.NumEncodedCols()) + + parallel.ExecuteChunky(len(linComb), func(start, stop int) { + subTask := make([]smartvectors.SmartVector, 0, len(rsCommittedSV)) + for i := range rsCommittedSV { + subTask = append(subTask, rsCommittedSV[i].SubVector(start, stop)) + } + + // Collect the result in the larger slice at the end + subResult := smartvectors.PolyEval(subTask, randomCoin) + subResult.WriteInSlice(linComb[start:stop]) + }) + + return &OpeningProof{ + LinearCombination: smartvectors.NewRegular(linComb), + } } // Complete completes the proof adding the columns pointed by entryList diff --git a/prover/protocol/column/status.go b/prover/protocol/column/status.go index 2e0838c9e22..f8d372be101 100644 --- a/prover/protocol/column/status.go +++ b/prover/protocol/column/status.go @@ -55,7 +55,7 @@ const ( // protocol. Meaning that this is not part of the proof. // // Deprecated: we don't really use this to create public inputs. - PublicInput + _ // VerifyingKey indicates the column is defined offline during the definition // of the protocol or the compilation and that the column is directly // available to the verifier. It is preferable to avoid tagging large @@ -82,8 +82,6 @@ func (s Status) String() string { return "PROOF" case Precomputed: return "PRECOMPUTED" - case PublicInput: - return "PUBLIC_INPUT" case VerifyingKey: return "VERIFYING_KEY" case VerifierDefined: @@ -95,7 +93,7 @@ func (s Status) String() string { // IsPublic returns true if the column is visible to the verifier func (s Status) IsPublic() bool { switch s { - case Proof, PublicInput, VerifyingKey, VerifierDefined: + case Proof, VerifyingKey, VerifierDefined: return true default: return false diff --git a/prover/protocol/column/store.go b/prover/protocol/column/store.go index bbcb4ad0862..0a7a8f18bb4 100644 --- a/prover/protocol/column/store.go +++ b/prover/protocol/column/store.go @@ -49,6 +49,12 @@ type storedColumnInfo struct { // FullRecursion. This field is only meaningfull for [Ignored] columns as // they are excluded by default. IncludeInProverFS bool + // ExcludeFromProverFS states the prover should not include the column in + // his FS transcript. This overrides [IncludeInProverFS], meaning that if + // [IncludeInProverFS] is true but ExcludeFromProverFS is true, the column + // will still be excluded from the transcript. This is used explicit FS + // compilation. + ExcludeFromProverFS bool } // AddToRound constructs a [Natural], registers it in the [Store] and returns @@ -165,19 +171,6 @@ func (r *Store) AllKeysProof() []ifaces.ColID { return res } -// AllKeysPublicInput returns the list of the [PublicInput] column's ID ordered -// by rounds and then by order ot insertion. -func (r *Store) AllKeysPublicInput() []ifaces.ColID { - res := []ifaces.ColID{} - - for round := 0; round < r.NumRounds(); round++ { - proof := r.AllKeysPublicInputAt(round) - res = append(res, proof...) - } - - return res -} - // AllKeysCommitted returns the list of all the IDs of the all the [Committed] // columns ordered by rounds and then by IDs. func (r *Store) AllKeysCommitted() []ifaces.ColID { @@ -227,22 +220,6 @@ func (r *Store) AllKeysProofAt(round int) []ifaces.ColID { return res } -// AllKeysPublicInputAt returns the list of all the prover messages in a given -// round. The resulting slice is ordered by order of insertion. -func (r *Store) AllKeysPublicInputAt(round int) []ifaces.ColID { - res := []ifaces.ColID{} - rnd := r.byRounds.MustGet(round) - - for i, info := range rnd { - if info.Status != PublicInput { - continue - } - res = append(res, rnd[i].ID) - } - - return res -} - // Returns the list of all the [Precomputed] columns' ID. The returned slice is // ordered by rounds and then by order of insertion. func (r *Store) AllPrecomputed() []ifaces.ColID { @@ -442,10 +419,6 @@ func assertCorrectStatusTransition(old, new Status) { // If it's ignored, it's ignored case old == Ignored && new != Ignored: forbiddenTransition = true - // You can't change the status of the public inputs because that would - // change the statement of the zkEVM. - case old == PublicInput && new != PublicInput: - forbiddenTransition = true // It's a special status and cannot be changed. case old == VerifierDefined && new != VerifierDefined: forbiddenTransition = true @@ -466,23 +439,46 @@ func (s *Store) IgnoreButKeepInProverTranscript(colName ifaces.ColID) { in.IncludeInProverFS = true } -// IsIgnoredAndNotKeptInTranscript indicates whether the column can be ignored -// from the transcript and is used during the Fiat-Shamir randomness generation. -func (s *Store) IsIgnoredAndNotKeptInTranscript(colName ifaces.ColID) bool { +// ExcludeFromProverFS marks a column as excluded from the FS transcript but +// without changing its status. This is used as part of the conglomeration +// where the imported columns take part in a separate FS transcript from the +// canonical of the host wizard. +func (s *Store) ExcludeFromProverFS(colName ifaces.ColID) { in := s.info(colName) - return in.Status == Ignored && !in.IncludeInProverFS + in.ExcludeFromProverFS = true +} + +// isExcludedFromProverFS returns true if the passed column ID relates to a column +// that does not take part in the FS transcript. +func (in *storedColumnInfo) isExcludedFromProverFS() bool { + + if in.ExcludeFromProverFS { + return true + } + + if in.IncludeInProverFS { + return false + } + + return true +} + +// IsExplicitlyExcludedFromProverFS returns true if the passed column ID relates to +// a column explicitly marked as excluded from the FS transcript. +func (s *Store) IsExplicitlyExcludedFromProverFS(colName ifaces.ColID) bool { + info := s.info(colName) + return info.ExcludeFromProverFS } -// AllKeysProofsOrIgnoredButKeptInProverTranscript returns the list of the -// columns to be used as part of the FS transcript. -func (s *Store) AllKeysProofsOrIgnoredButKeptInProverTranscript(round int) []ifaces.ColID { +// AllKeysInProverTranscript returns the list of the columns to +// be used as part of the FS transcript. +func (s *Store) AllKeysInProverTranscript(round int) []ifaces.ColID { res := []ifaces.ColID{} rnd := s.byRounds.MustGet(round) // precomputed are always at round zero for i, info := range rnd { - ok := (info.Status == Proof) || (info.Status == Ignored && info.IncludeInProverFS) - if !ok { + if info.isExcludedFromProverFS() { continue } diff --git a/prover/protocol/column/verifiercol/from_alleged_ys.go b/prover/protocol/column/verifiercol/from_alleged_ys.go index 6d332fffb23..8a9cce0bc5c 100644 --- a/prover/protocol/column/verifiercol/from_alleged_ys.go +++ b/prover/protocol/column/verifiercol/from_alleged_ys.go @@ -9,7 +9,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/sirupsen/logrus" + "github.com/consensys/linea-monorepo/prover/utils" ) // compile check to enforce the struct to belong to the corresponding interface @@ -39,10 +39,9 @@ func NewFromYs(comp *wizard.CompiledIOP, q query.UnivariateEval, ranges []ifaces nameMap[polName.GetColID()] = struct{}{} } - // No make the explicit check for _, rangeName := range ranges { if _, ok := nameMap[rangeName]; !ok && !strings.Contains(string(rangeName), "SHADOW") { - logrus.Debugf("NewFromYs : %v is not part of the query %v. It will be zeroized", rangeName, q.QueryID) + utils.Panic("NewFromYs : %v is not part of the query %v", rangeName, q.QueryID) } } diff --git a/prover/protocol/compiler/dummy/dummy.go b/prover/protocol/compiler/dummy/dummy.go index c0b3415e11b..e28f0e856fe 100644 --- a/prover/protocol/compiler/dummy/dummy.go +++ b/prover/protocol/compiler/dummy/dummy.go @@ -84,7 +84,7 @@ func Compile(comp *wizard.CompiledIOP) { One step to be run at the end, by verifying every constraint "a la mano" */ - verifier := func(run *wizard.VerifierRuntime) error { + verifier := func(run wizard.Runtime) error { logrus.Infof("started to run the dummy verifier") @@ -137,6 +137,6 @@ func Compile(comp *wizard.CompiledIOP) { } logrus.Debugf("NB: The gnark circuit does not check the verifier of the dummy reduction\n") - comp.InsertVerifier(numRounds-1, verifier, func(frontend.API, *wizard.WizardVerifierCircuit) {}) + comp.InsertVerifier(numRounds-1, verifier, func(frontend.API, wizard.GnarkRuntime) {}) } diff --git a/prover/protocol/compiler/fullrecursion/actions.go b/prover/protocol/compiler/fullrecursion/actions.go index 44927aa99a9..228287cd309 100644 --- a/prover/protocol/compiler/fullrecursion/actions.go +++ b/prover/protocol/compiler/fullrecursion/actions.go @@ -73,11 +73,11 @@ func (c LocalOpeningAssignment) Run(run *wizard.ProverRuntime) { } } -func (c *ConsistencyCheck) Run(run *wizard.VerifierRuntime) error { +func (c *ConsistencyCheck) Run(run wizard.Runtime) error { var ( initialFsCirc = run.GetLocalPointEvalParams(c.LocalOpenings[0].ID).Y - initialFsRt = run.FiatShamirHistory[c.FirstRound+1][0][0] + initialFsRt = run.FsHistory()[c.FirstRound+1][0][0] piCursor = 2 ) @@ -131,11 +131,11 @@ func (c *ConsistencyCheck) Run(run *wizard.VerifierRuntime) error { return nil } -func (c *ConsistencyCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (c *ConsistencyCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime) { var ( initialFsCirc = run.GetLocalPointEvalParams(c.LocalOpenings[0].ID).Y - initialFsRt = run.FiatShamirHistory[c.FirstRound+1][0][0] + initialFsRt = run.FsHistory()[c.FirstRound+1][0][0] piCursor = 2 ) @@ -187,15 +187,15 @@ func (c *ConsistencyCheck) IsSkipped() bool { return c.isSkipped } -func (r *ResetFsActions) Run(run *wizard.VerifierRuntime) error { +func (r *ResetFsActions) Run(run wizard.Runtime) error { finalFsCirc := run.GetLocalPointEvalParams(r.LocalOpenings[1].ID).Y - run.FS.SetState([]field.Element{finalFsCirc}) + run.Fs().SetState([]field.Element{finalFsCirc}) return nil } -func (r *ResetFsActions) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (r *ResetFsActions) RunGnark(api frontend.API, run wizard.GnarkRuntime) { finalFsCirc := run.GetLocalPointEvalParams(r.LocalOpenings[1].ID).Y - run.FS.SetState([]frontend.Variable{finalFsCirc}) + run.Fs().SetState([]frontend.Variable{finalFsCirc}) } func (r *ResetFsActions) Skip() { diff --git a/prover/protocol/compiler/fullrecursion/circuit.go b/prover/protocol/compiler/fullrecursion/circuit.go index 2151dfa7f33..9e92c4515f6 100644 --- a/prover/protocol/compiler/fullrecursion/circuit.go +++ b/prover/protocol/compiler/fullrecursion/circuit.go @@ -20,7 +20,7 @@ type gnarkCircuit struct { X frontend.Variable `gnark:",public"` Ys []frontend.Variable `gnark:",public"` Pubs []frontend.Variable `gnark:",public"` - WizardVerifier *wizard.WizardVerifierCircuit + WizardVerifier wizard.GnarkRuntime comp *wizard.CompiledIOP `gnark:"-"` ctx *fullRecursionCtx `gnark:"-"` withoutGkr bool `gnark:"-"` @@ -66,7 +66,7 @@ func allocateGnarkCircuit(comp *wizard.CompiledIOP, ctx *fullRecursionCtx) *gnar func (c *gnarkCircuit) Define(api frontend.API) error { - w := c.WizardVerifier + w := c.WizardVerifier.(*wizard.WizardVerifierCircuit) if c.withoutGkr { w.FS = fiatshamir.NewGnarkFiatShamir(api, nil) @@ -116,7 +116,7 @@ func (c *gnarkCircuit) generateAllRandomCoins(api frontend.API) { var ( ctx = c.ctx - w = c.WizardVerifier + w = c.WizardVerifier.(*wizard.WizardVerifierCircuit) ) w.FS.SetState([]frontend.Variable{c.InitialFsState}) @@ -129,6 +129,11 @@ func (c *gnarkCircuit) generateAllRandomCoins(api frontend.API) { toUpdateFS := ctx.Columns[currRound-1] for _, msg := range toUpdateFS { + + if c.comp.Columns.IsExplicitlyExcludedFromProverFS(msg.GetColID()) { + continue + } + val := w.GetColumn(msg.GetColID()) w.FS.UpdateVec(val) } diff --git a/prover/protocol/compiler/globalcs/evaluation.go b/prover/protocol/compiler/globalcs/evaluation.go index b612b78ec30..2fe167f50d1 100644 --- a/prover/protocol/compiler/globalcs/evaluation.go +++ b/prover/protocol/compiler/globalcs/evaluation.go @@ -165,7 +165,7 @@ func (pa evaluationProver) Run(run *wizard.ProverRuntime) { } // Run evaluate the constraint and checks that -func (ctx *evaluationVerifier) Run(run *wizard.VerifierRuntime) error { +func (ctx *evaluationVerifier) Run(run wizard.Runtime) error { var ( // Will be assigned to "X", the random point at which we check the constraint. @@ -239,7 +239,7 @@ func (ctx *evaluationVerifier) Run(run *wizard.VerifierRuntime) error { } // Verifier step, evaluate the constraint and checks that -func (ctx *evaluationVerifier) RunGnark(api frontend.API, c *wizard.WizardVerifierCircuit) { +func (ctx *evaluationVerifier) RunGnark(api frontend.API, c wizard.GnarkRuntime) { // Will be assigned to "X", the random point at which we check the constraint. r := c.GetRandomCoinField(ctx.EvalCoin.Name) @@ -299,7 +299,7 @@ func (ctx *evaluationVerifier) RunGnark(api frontend.API, c *wizard.WizardVerifi // recombineQuotientSharesEvaluation returns the evaluations of the quotients // on point r -func (ctx evaluationVerifier) recombineQuotientSharesEvaluation(run *wizard.VerifierRuntime, r field.Element) ([]field.Element, error) { +func (ctx evaluationVerifier) recombineQuotientSharesEvaluation(run wizard.Runtime, r field.Element) ([]field.Element, error) { var ( // res stores the list of the recombined quotient evaluations for each @@ -386,7 +386,7 @@ func (ctx evaluationVerifier) recombineQuotientSharesEvaluation(run *wizard.Veri // recombineQuotientSharesEvaluation returns the evaluations of the quotients // on point r -func (ctx evaluationVerifier) recombineQuotientSharesEvaluationGnark(api frontend.API, run *wizard.WizardVerifierCircuit, r frontend.Variable) []frontend.Variable { +func (ctx evaluationVerifier) recombineQuotientSharesEvaluationGnark(api frontend.API, run wizard.GnarkRuntime, r frontend.Variable) []frontend.Variable { var ( // res stores the list of the recombined quotient evaluations for each diff --git a/prover/protocol/compiler/grandproduct/compiler.go b/prover/protocol/compiler/grandproduct/compiler.go index 370a443d01a..8d7f2a58559 100644 --- a/prover/protocol/compiler/grandproduct/compiler.go +++ b/prover/protocol/compiler/grandproduct/compiler.go @@ -101,7 +101,7 @@ type FinalProductCheck struct { } // Run implements the [wizard.VerifierAction] -func (f *FinalProductCheck) Run(run *wizard.VerifierRuntime) error { +func (f *FinalProductCheck) Run(run wizard.Runtime) error { // zProd stores the product of the ending values of the zs as queried // in the protocol via the local opening queries. @@ -122,7 +122,7 @@ func (f *FinalProductCheck) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] -func (f *FinalProductCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (f *FinalProductCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime) { claimedProd := run.GetGrandProductParams(f.GrandProductID).Prod // zProd stores the product of the ending values of the z columns diff --git a/prover/protocol/compiler/innerproduct/verifier.go b/prover/protocol/compiler/innerproduct/verifier.go index accc5acd0e9..cb0c21e06c3 100644 --- a/prover/protocol/compiler/innerproduct/verifier.go +++ b/prover/protocol/compiler/innerproduct/verifier.go @@ -24,7 +24,7 @@ type verifierForSize struct { } // Run implements [wizard.VerifierAction] -func (v *verifierForSize) Run(run *wizard.VerifierRuntime) error { +func (v *verifierForSize) Run(run wizard.Runtime) error { var ( // ys stores the list of all the inner-product openings @@ -59,7 +59,7 @@ func (v *verifierForSize) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] interface -func (v *verifierForSize) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (v *verifierForSize) RunGnark(api frontend.API, run wizard.GnarkRuntime) { var ( // ys stores the list of all the inner-product openings diff --git a/prover/protocol/compiler/logderivativesum/compile.go b/prover/protocol/compiler/logderivativesum/compile.go index b6796b96866..bf584483dd9 100644 --- a/prover/protocol/compiler/logderivativesum/compile.go +++ b/prover/protocol/compiler/logderivativesum/compile.go @@ -75,7 +75,7 @@ type FinalEvaluationCheck struct { } // Run implements the [wizard.VerifierAction] -func (f *FinalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { +func (f *FinalEvaluationCheck) Run(run wizard.Runtime) error { // zSum stores the sum of the ending values of the zs as queried // in the protocol via the local opening queries. @@ -96,7 +96,7 @@ func (f *FinalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] -func (f *FinalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (f *FinalEvaluationCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime) { claimedSum := run.GetLogDerivSumParams(f.LogDerivSumID).Sum // SigmaSKSum stores the sum of the ending values of the SigmaSs as queried diff --git a/prover/protocol/compiler/lookup/verifier.go b/prover/protocol/compiler/lookup/verifier.go index be3219d085d..042c6262cd9 100644 --- a/prover/protocol/compiler/lookup/verifier.go +++ b/prover/protocol/compiler/lookup/verifier.go @@ -25,7 +25,7 @@ type finalEvaluationCheck struct { } // Run implements the [wizard.VerifierAction] -func (f *finalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { +func (f *finalEvaluationCheck) Run(run wizard.Runtime) error { // zSum stores the sum of the ending values of the zs as queried // in the protocol via the local opening queries. @@ -43,7 +43,7 @@ func (f *finalEvaluationCheck) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] -func (f *finalEvaluationCheck) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (f *finalEvaluationCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime) { // SigmaSKSum stores the sum of the ending values of the SigmaSs as queried // in the protocol via the diff --git a/prover/protocol/compiler/permutation/verifier.go b/prover/protocol/compiler/permutation/verifier.go index d076e643ba1..5b685d304d0 100644 --- a/prover/protocol/compiler/permutation/verifier.go +++ b/prover/protocol/compiler/permutation/verifier.go @@ -18,7 +18,7 @@ type VerifierCtx struct { // Run implements the [wizard.VerifierAction] interface and checks that the // product of the products given by the ZCtx is equal to one. -func (v *VerifierCtx) Run(run *wizard.VerifierRuntime) error { +func (v *VerifierCtx) Run(run wizard.Runtime) error { mustBeOne := field.One() @@ -38,7 +38,7 @@ func (v *VerifierCtx) Run(run *wizard.VerifierRuntime) error { // Run implements the [wizard.VerifierAction] interface and is as // [VerifierCtx.Run] but in the context of a gnark circuit. -func (v *VerifierCtx) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (v *VerifierCtx) RunGnark(api frontend.API, run wizard.GnarkRuntime) { mustBeOne := frontend.Variable(1) diff --git a/prover/protocol/compiler/projection/verifier.go b/prover/protocol/compiler/projection/verifier.go index 59b4f4d0d5d..7ac500068af 100644 --- a/prover/protocol/compiler/projection/verifier.go +++ b/prover/protocol/compiler/projection/verifier.go @@ -20,7 +20,7 @@ type projectionVerifierAction struct { } // Run implements the [wizard.VerifierAction] interface. -func (va *projectionVerifierAction) Run(run *wizard.VerifierRuntime) error { +func (va *projectionVerifierAction) Run(run wizard.Runtime) error { var ( a = run.GetLocalPointEvalParams(va.HornerA0.ID).Y @@ -35,7 +35,7 @@ func (va *projectionVerifierAction) Run(run *wizard.VerifierRuntime) error { } // RunGnark implements the [wizard.VerifierAction] interface. -func (va *projectionVerifierAction) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (va *projectionVerifierAction) RunGnark(api frontend.API, run wizard.GnarkRuntime) { var ( a = run.GetLocalPointEvalParams(va.HornerA0.ID).Y diff --git a/prover/protocol/compiler/selfrecursion/column_opening.go b/prover/protocol/compiler/selfrecursion/column_opening.go index 1091010eec8..a7eca71020e 100644 --- a/prover/protocol/compiler/selfrecursion/column_opening.go +++ b/prover/protocol/compiler/selfrecursion/column_opening.go @@ -255,8 +255,8 @@ func (ctx *SelfRecursionCtx) linearHashAndMerkle() { } for round := 0; round <= totalNumRounds; round++ { - colSisHashName := ctx.VortexCtx.CommitmentName(round) - colSisHashSV, found := run.State.TryGet(string(colSisHashName)) + colSisHashName := ctx.VortexCtx.SisHashName(round) + colSisHashSV, found := run.State.TryGet(colSisHashName) if !found { // continue with the same committedRound until we meet a non-dry // round or we reach the total number of committed rounds @@ -291,7 +291,7 @@ func (ctx *SelfRecursionCtx) linearHashAndMerkle() { } // Frees the colSisHash - run.State.TryDel(string(colSisHashName)) + run.State.TryDel(colSisHashName) // Increment only if the committedRound is non-dry committedRound++ @@ -366,6 +366,7 @@ func (ctx *SelfRecursionCtx) linearHashAndMerkle() { // And the linear hashing mimcW.CheckLinearHash( ctx.comp, + ctx.linearHashVerificationName(), ctx.Columns.ConcatenatedDhQ, ctx.VortexCtx.SisParams.OutputSize(), leavesSizeUnpadded, @@ -404,7 +405,7 @@ func (ctx *SelfRecursionCtx) collapsingPhase() { // Consistency check between the collapsed preimage and UalphaQ { - left := functionals.CoeffEval( + uAlphaQEval := functionals.CoeffEval( ctx.comp, ctx.constencyUalphaQPreimageLeft(), ctx.Coins.Collapse, @@ -423,7 +424,7 @@ func (ctx *SelfRecursionCtx) collapsingPhase() { evaluation point, we get a bivariate polynomial evaluation */ - right := functionals.EvalCoeffBivariate( + preImageEval := functionals.EvalCoeffBivariate( ctx.comp, ctx.constencyUalphaQPreimageRight(), ctx.Columns.PreimagesCollapse, @@ -434,21 +435,21 @@ func (ctx *SelfRecursionCtx) collapsingPhase() { ) ctx.comp.InsertVerifier( - left.Round(), - func(run *wizard.VerifierRuntime) error { - if left.GetVal(run) != right.GetVal(run) { - l, r := left.GetVal(run), right.GetVal(run) + uAlphaQEval.Round(), + func(run wizard.Runtime) error { + if uAlphaQEval.GetVal(run) != preImageEval.GetVal(run) { + l, r := uAlphaQEval.GetVal(run), preImageEval.GetVal(run) return fmt.Errorf("consistency between u_alpha and the preimage: "+ - "mismatch between left and right %v != %v", + "mismatch between uAlphaQEval=%v preimages=%v", l.String(), r.String(), ) } return nil }, - func(api frontend.API, run *wizard.WizardVerifierCircuit) { + func(api frontend.API, run wizard.GnarkRuntime) { api.AssertIsEqual( - left.GetFrontendVariable(api, run), - right.GetFrontendVariable(api, run), + uAlphaQEval.GetFrontendVariable(api, run), + preImageEval.GetFrontendVariable(api, run), ) }, ) @@ -621,7 +622,7 @@ func (ctx *SelfRecursionCtx) foldPhase() { // And the final check // check the folding of the polynomial is correct - ctx.comp.InsertVerifier(round, func(run *wizard.VerifierRuntime) error { + ctx.comp.InsertVerifier(round, func(run wizard.Runtime) error { // fetch the assignments to edual and dcollapse edual := ctx.Columns.Edual.GetColAssignment(run) @@ -666,7 +667,7 @@ func (ctx *SelfRecursionCtx) foldPhase() { } return nil - }, func(api frontend.API, run *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, run wizard.GnarkRuntime) { // fetch the assignments to edual and dcollapse edual := ctx.Columns.Edual.GetColAssignmentGnark(run) diff --git a/prover/protocol/compiler/selfrecursion/context.go b/prover/protocol/compiler/selfrecursion/context.go index 5226926f964..ae226a8c906 100644 --- a/prover/protocol/compiler/selfrecursion/context.go +++ b/prover/protocol/compiler/selfrecursion/context.go @@ -25,6 +25,11 @@ type SelfRecursionCtx struct { // step. SelfRecursionCnt int + // NamePrefix is a prefix for the names of the items generated by + // the current self-recursion compilation context. Leaving it as + // empty means that no prefix is used. + NamePrefix string + // Accessors Accessors struct { // (EvalBivariate) @@ -197,17 +202,16 @@ type SelfRecursionCtx struct { } } -// Initializes a context for the self recursion -func NewSelfRecursionCxt(comp *wizard.CompiledIOP) SelfRecursionCtx { - - // Extract the vortex context from the compiledIOP though - // the "CryptographicCompilerCtx" - vortexCtx := assertVortexCompiled(comp) +// NewRecursionCtx returns a new recursion context taking a specified +// [vortex.Ctx] and SelfRecursionCount. It can be used for custom use +// of the recursion wizard. +func NewRecursionCtx(comp *wizard.CompiledIOP, vortexCtx *vortex.Ctx, prefix string) SelfRecursionCtx { ctx := SelfRecursionCtx{ comp: comp, VortexCtx: vortexCtx, SelfRecursionCnt: comp.SelfRecursionCount, + NamePrefix: prefix, } // Transport the compilation items of the vortex context into @@ -238,7 +242,7 @@ func NewSelfRecursionCxt(comp *wizard.CompiledIOP) SelfRecursionCtx { // Assume that the rounds commitments have a `Proof` status if comp.Columns.Status(rooth.GetColID()) != column.Proof { utils.Panic( - "Assumed the Dh to be %v but status is %v", + "Assumed the rootH to be %v but status is %v", column.Proof.String(), comp.Columns.Status(rooth.GetColID()), ) @@ -280,6 +284,13 @@ func NewSelfRecursionCxt(comp *wizard.CompiledIOP) SelfRecursionCtx { return ctx } +// Initializes a context for the self recursion +func NewSelfRecursionCxt(comp *wizard.CompiledIOP) SelfRecursionCtx { + // Extract the vortex context from the compiledIOP though the "Pcs" + vortexCtx := assertVortexCompiled(comp) + return NewRecursionCtx(comp, vortexCtx, "") +} + // Asserts that the compiled IOP has the appropriate cryptographic context func assertVortexCompiled(comp *wizard.CompiledIOP) *vortex.Ctx { // When we compiled using Vortex, we annotated the compiledIOP diff --git a/prover/protocol/compiler/selfrecursion/lincomb_phase.go b/prover/protocol/compiler/selfrecursion/lincomb_phase.go index 7c9a6dbc0d8..ea60c55bb7c 100644 --- a/prover/protocol/compiler/selfrecursion/lincomb_phase.go +++ b/prover/protocol/compiler/selfrecursion/lincomb_phase.go @@ -70,18 +70,18 @@ func (ctx *SelfRecursionCtx) consistencyBetweenYsAndUalpha() { // And let the verifier check that they should be both equal ctx.comp.InsertVerifier( round, - func(run *wizard.VerifierRuntime) error { + func(run wizard.Runtime) error { ys := ctx.Columns.Ys.GetColAssignment(run) alpha := run.GetRandomCoinField(ctx.Coins.Alpha.Name) ysAlpha := smartvectors.EvalCoeff(ys, alpha) uAlphaX := ctx.Accessors.InterpolateUalphaX.GetVal(run) if uAlphaX != ysAlpha { - return fmt.Errorf("ConsistencyBetweenYsAndUalpha did not pass") + return fmt.Errorf("ConsistencyBetweenYsAndUalpha did not pass, ysAlphaX=%v uAlphaX=%v", ysAlpha.String(), uAlphaX.String()) } return nil }, - func(api frontend.API, run *wizard.WizardVerifierCircuit) { + func(api frontend.API, run wizard.GnarkRuntime) { ys := ctx.Columns.Ys.GetColAssignmentGnark(run) alpha := run.GetRandomCoinField(ctx.Coins.Alpha.Name) uAlphaX := ctx.Accessors.InterpolateUalphaX.GetFrontendVariable(api, run) diff --git a/prover/protocol/compiler/selfrecursion/names.go b/prover/protocol/compiler/selfrecursion/names.go index ffb8bcba709..8fcc66112d3 100644 --- a/prover/protocol/compiler/selfrecursion/names.go +++ b/prover/protocol/compiler/selfrecursion/names.go @@ -12,7 +12,8 @@ import ( // Name of the polynomial I(x) func (ctx *SelfRecursionCtx) iName(length int) ifaces.ColID { - return ifaces.ColIDf("PRECOMPUTED_%v_I_%v", ctx.SelfRecursionCnt, length) + name := ifaces.ColIDf("PRECOMPUTED_%v_I_%v", ctx.SelfRecursionCnt, length) + return maybePrefix(ctx, name) } // Name of the aH polynomials @@ -23,98 +24,129 @@ func (ctx *SelfRecursionCtx) ahName(key *ringsis.Key, start, length, maxSize int } subName := ifaces.ColIDf("SISKEY_%v_%v_%v", key.LogTwoBound, key.LogTwoDegree, key.MaxNumFieldHashable()) - return ifaces.ColIDf("%v_%v_%v_%v", subName, start, length, maxSize) + name := ifaces.ColIDf("%v_%v_%v_%v", subName, start, length, maxSize) + return maybePrefix(ctx, name) } // Name of the preimage in limb expanded. nameWhole is the name of the // associated column without limb expansion. func (ctx *SelfRecursionCtx) limbExpandedPreimageName(nameWhole ifaces.ColID) ifaces.ColID { - return ifaces.ColIDf("%v_LIMB_EXPANDED_%v", nameWhole, ctx.SelfRecursionCnt) + name := ifaces.ColIDf("%v_LIMB_EXPANDED_%v", nameWhole, ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the UalphaQ column func (ctx *SelfRecursionCtx) uAlphaQName() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_U_ALPHA_Q_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_U_ALPHA_Q_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the self-recursed inclusion query func (ctx *SelfRecursionCtx) selectQInclusion() ifaces.QueryID { - return ifaces.QueryIDf("SELFRECURSION_SELECT_Q_INCLUSION_%v", ctx.SelfRecursionCnt) + name := ifaces.QueryIDf("SELFRECURSION_SELECT_Q_INCLUSION_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the collapse coin func (ctx *SelfRecursionCtx) collapseCoin() coin.Name { - return coin.Namef("SELFRECURSION_COLLAPSE_COIN_%v", ctx.SelfRecursionCnt) + name := coin.Namef("SELFRECURSION_COLLAPSE_COIN_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name for the coeff eval for consistency check between Ualphaq // and the preimage (left-side, over UalphaA) func (ctx *SelfRecursionCtx) constencyUalphaQPreimageLeft() string { - return fmt.Sprintf("SELFRECURSION_CONSISTENCY_UALPHA_PREIMAGE_LEFT_%v", ctx.SelfRecursionCnt) + name := fmt.Sprintf("SELFRECURSION_CONSISTENCY_UALPHA_PREIMAGE_LEFT_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name for the coeff eval for consistency check between Ualphaq // and the preimage (right-side, over preiimages) func (ctx *SelfRecursionCtx) constencyUalphaQPreimageRight() string { - return fmt.Sprintf("SELFRECURSION_CONSISTENCY_UALPHA_PREIMAGE_RIGHT_%v", ctx.SelfRecursionCnt) + name := fmt.Sprintf("SELFRECURSION_CONSISTENCY_UALPHA_PREIMAGE_RIGHT_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of Edual func (ctx *SelfRecursionCtx) eDual() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_E_DUAL_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_E_DUAL_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name for the fold coin func (ctx *SelfRecursionCtx) foldCoinName() coin.Name { - return coin.Namef("SELFRECURSION_FOLD_COIN_%v", ctx.SelfRecursionCnt) + name := coin.Namef("SELFRECURSION_FOLD_COIN_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the inner product between between PreimageCollapseFold and ACollapseFold func (ctx *SelfRecursionCtx) preimagesAndAmergeIP() ifaces.QueryID { - return ifaces.QueryIDf("SELFRECURSION_PREIMAGE_A_IP_%v", ctx.SelfRecursionCnt) + name := ifaces.QueryIDf("SELFRECURSION_PREIMAGE_A_IP_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the interpolation context for Ualpha func (ctx *SelfRecursionCtx) interpolateUAlphaX() string { - return fmt.Sprintf("SELFRECURSION_INTERPOLATE_UALPHA_X_%v", ctx.SelfRecursionCnt) + name := fmt.Sprintf("SELFRECURSION_INTERPOLATE_UALPHA_X_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the concatenation of the DhQs func (ctx *SelfRecursionCtx) concatenatedDhQ() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_CONCAT_DHQ_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_CONCAT_DHQ_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the MerkleLeaves func (ctx *SelfRecursionCtx) merkleLeavesName() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_MERKLE_LEAVES_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_MERKLE_LEAVES_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the MerklePositions func (ctx *SelfRecursionCtx) merklePositionssName() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_MERKLE_POSITIONS_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_MERKLE_POSITIONS_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the MerkleRoots func (ctx *SelfRecursionCtx) merkleRootsName() ifaces.ColID { - return ifaces.ColIDf("SELFRECURSION_MERKLE_ROOTS_%v", ctx.SelfRecursionCnt) + name := ifaces.ColIDf("SELFRECURSION_MERKLE_ROOTS_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the Merkle proof verification func (ctx *SelfRecursionCtx) merkleProofVerificationName() string { - return fmt.Sprintf("SELFRECURSION_MERKLE_%v", ctx.SelfRecursionCnt) + name := fmt.Sprintf("SELFRECURSION_MERKLE_%v", ctx.SelfRecursionCnt) + return maybePrefix(ctx, name) } // Name of the collapsed key func (ctx *SelfRecursionCtx) aCollapsedName() string { - return fmt.Sprintf("SELFRECURSION_ACOLLAPSE_%v", ctx.comp.SelfRecursionCount) + name := fmt.Sprintf("SELFRECURSION_ACOLLAPSE_%v", ctx.comp.SelfRecursionCount) + return maybePrefix(ctx, name) } // Name of the collapsed key func (ctx *SelfRecursionCtx) rootHasGlue() ifaces.QueryID { - return ifaces.QueryIDf("SELFRECURSION_ROOT_HASH_GLUE_%v", ctx.comp.SelfRecursionCount) + name := ifaces.QueryIDf("SELFRECURSION_ROOT_HASH_GLUE_%v", ctx.comp.SelfRecursionCount) + return maybePrefix(ctx, name) } // Positions glue func (ctx *SelfRecursionCtx) positionGlue() ifaces.QueryID { - return ifaces.QueryIDf("SELFRECURSION_POSITION_GLUE_%v", ctx.comp.SelfRecursionCount) + name := ifaces.QueryIDf("SELFRECURSION_POSITION_GLUE_%v", ctx.comp.SelfRecursionCount) + return maybePrefix(ctx, name) +} + +// linearHashVerificatioName returns the name passed to the wizard helper building the +// linear hash verifier. +func (ctx *SelfRecursionCtx) linearHashVerificationName() string { + name := fmt.Sprintf("SELFRECURSION_LINEAR_HASH_VERIFICATION_%v", ctx.comp.SelfRecursionCount) + return maybePrefix(ctx, name) +} + +// maybePrefix adds the prefix if defined in the context +func maybePrefix[T ~string](ctx *SelfRecursionCtx, name T) T { + return T(ctx.NamePrefix+".") + name } diff --git a/prover/protocol/compiler/selfrecursion/precomputations.go b/prover/protocol/compiler/selfrecursion/precomputations.go index e1fd9589734..159aba4d3be 100644 --- a/prover/protocol/compiler/selfrecursion/precomputations.go +++ b/prover/protocol/compiler/selfrecursion/precomputations.go @@ -95,7 +95,7 @@ func (ctx *SelfRecursionCtx) registersAh() { // associated Dh should be nil. That happens when the examinated round // is a "dry" round or when it has been self-recursed already. if (len(comsInRoundsI) == 0) != (ctx.Columns.Rooth[i] == nil) { - panic("nilness mismatch") + utils.Panic("nilness mismatch for round=%v #coms-in-round=%v vs root-is-nil=%v", i, len(comsInRoundsI), ctx.Columns.Rooth[i] == nil) } // Check if there is no rows to commit diff --git a/prover/protocol/compiler/selfrecursion/selfrecursion.go b/prover/protocol/compiler/selfrecursion/selfrecursion.go index edb1189e3fb..bc38923b527 100644 --- a/prover/protocol/compiler/selfrecursion/selfrecursion.go +++ b/prover/protocol/compiler/selfrecursion/selfrecursion.go @@ -1,6 +1,7 @@ package selfrecursion import ( + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/sirupsen/logrus" ) @@ -19,3 +20,12 @@ func SelfRecurse(comp *wizard.CompiledIOP) { // Update the self-recursion counter comp.SelfRecursionCount++ } + +// RecurseOverCustomCtx applies the same compilation steps as [SelfRecurse] +// over a specified vortex compilation context. +func RecurseOverCustomCtx(comp *wizard.CompiledIOP, vortexCtx *vortex.Ctx, prefix string) { + ctx := NewRecursionCtx(comp, vortexCtx, prefix) + ctx.Precomputations() + ctx.RowLinearCombinationPhase() + ctx.ColumnOpeningPhase() +} diff --git a/prover/protocol/compiler/splitter/splitter.go b/prover/protocol/compiler/splitter/splitter.go index 062fb8ebc4c..c44de922d71 100644 --- a/prover/protocol/compiler/splitter/splitter.go +++ b/prover/protocol/compiler/splitter/splitter.go @@ -250,13 +250,13 @@ func (ctx splitterCtx) compileGlobal(comp *wizard.CompiledIOP, q query.GlobalCon } // Requires the verifier to verify the query itself - comp.InsertVerifier(round, func(vr *wizard.VerifierRuntime) error { + comp.InsertVerifier(round, func(vr wizard.Runtime) error { err := q.Check(vr) if err != nil { return fmt.Errorf("failure for query %v, here is why %v", q.ID, err) } return nil - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { q.CheckGnark(api, wvc) }) @@ -416,13 +416,13 @@ func (ctx splitterCtx) compileLocal(comp *wizard.CompiledIOP, q query.LocalConst } // Requires the verifier to verify the query itself - comp.InsertVerifier(round, func(vr *wizard.VerifierRuntime) error { + comp.InsertVerifier(round, func(vr wizard.Runtime) error { err := q.Check(vr) if err != nil { return fmt.Errorf("failure for query %v, here is why %v", q.ID, err) } return nil - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { q.CheckGnark(api, wvc) }) @@ -488,9 +488,9 @@ func (ctx splitterCtx) compileLocalOpening(comp *wizard.CompiledIOP, q query.Loc verifiercol.AssertIsPublicCol(comp, q.Pol) // Requires the verifier to verify the query itself - comp.InsertVerifier(round, func(vr *wizard.VerifierRuntime) error { + comp.InsertVerifier(round, func(vr wizard.Runtime) error { return q.Check(vr) - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { q.CheckGnark(api, wvc) }) @@ -522,7 +522,7 @@ func (ctx splitterCtx) compileLocalOpening(comp *wizard.CompiledIOP, q query.Loc }) // The verifier ensures that the old and new queries have the same assignement - comp.InsertVerifier(round, func(run *wizard.VerifierRuntime) error { + comp.InsertVerifier(round, func(run wizard.Runtime) error { oldParams := run.GetLocalPointEvalParams(q.ID) newParams := run.GetLocalPointEvalParams(newQName) @@ -531,7 +531,7 @@ func (ctx splitterCtx) compileLocalOpening(comp *wizard.CompiledIOP, q query.Loc } return nil - }, func(api frontend.API, run *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, run wizard.GnarkRuntime) { oldParams := run.GetLocalPointEvalParams(q.ID) newParams := run.GetLocalPointEvalParams(newQName) api.AssertIsEqual(oldParams.Y, newParams.Y) diff --git a/prover/protocol/compiler/splitter/sticker/sticker.go b/prover/protocol/compiler/splitter/sticker/sticker.go index e90ec7ff48d..b03ccaf2075 100644 --- a/prover/protocol/compiler/splitter/sticker/sticker.go +++ b/prover/protocol/compiler/splitter/sticker/sticker.go @@ -438,7 +438,7 @@ func (ctx *stickContext) compileFixedEvaluation() { }) // The verifier ensures that the old and new queries have the same assignement - ctx.comp.InsertVerifier(round, func(run *wizard.VerifierRuntime) error { + ctx.comp.InsertVerifier(round, func(run wizard.Runtime) error { oldParams := run.GetLocalPointEvalParams(q.ID) newParams := run.GetLocalPointEvalParams(queryName(q.ID)) @@ -447,7 +447,7 @@ func (ctx *stickContext) compileFixedEvaluation() { } return nil - }, func(api frontend.API, run *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, run wizard.GnarkRuntime) { oldParams := run.GetLocalPointEvalParams(q.ID) newParams := run.GetLocalPointEvalParams(queryName(q.ID)) api.AssertIsEqual(oldParams.Y, newParams.Y) diff --git a/prover/protocol/compiler/stitch_split/stitcher/constraints.go b/prover/protocol/compiler/stitch_split/stitcher/constraints.go index 7d269a2e7e4..25d778a3954 100644 --- a/prover/protocol/compiler/stitch_split/stitcher/constraints.go +++ b/prover/protocol/compiler/stitch_split/stitcher/constraints.go @@ -248,9 +248,9 @@ func insertVerifier( ) { // Requires the verifier to verify the query itself - comp.InsertVerifier(round, func(vr *wizard.VerifierRuntime) error { + comp.InsertVerifier(round, func(vr wizard.Runtime) error { return q.Check(vr) - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { q.CheckGnark(api, wvc) }) diff --git a/prover/protocol/compiler/univariates/local_opening_point.go b/prover/protocol/compiler/univariates/local_opening_point.go index e3896cdd0c3..c055f62a697 100644 --- a/prover/protocol/compiler/univariates/local_opening_point.go +++ b/prover/protocol/compiler/univariates/local_opening_point.go @@ -96,7 +96,7 @@ func (ctx *localOpeningCtx) prover(assi *wizard.ProverRuntime) { assi.AssignUnivariate(ctx.fixedToVariable(), field.One(), ys...) } -func (ctx localOpeningCtx) verifier(assi *wizard.VerifierRuntime) error { +func (ctx localOpeningCtx) verifier(assi wizard.Runtime) error { ys := []field.Element{} // Collect the evaluation from the assigned compiled queries @@ -131,7 +131,7 @@ func (ctx localOpeningCtx) verifier(assi *wizard.VerifierRuntime) error { return nil } -func (ctx localOpeningCtx) gnarkVerifier(api frontend.API, c *wizard.WizardVerifierCircuit) { +func (ctx localOpeningCtx) gnarkVerifier(api frontend.API, c wizard.GnarkRuntime) { ys := []frontend.Variable{} // Collect the evaluation from the assigned compiled queries diff --git a/prover/protocol/compiler/univariates/multi_to_single_point.go b/prover/protocol/compiler/univariates/multi_to_single_point.go index 27020fbe82b..f9c560cb302 100644 --- a/prover/protocol/compiler/univariates/multi_to_single_point.go +++ b/prover/protocol/compiler/univariates/multi_to_single_point.go @@ -2,12 +2,13 @@ package univariates import ( "fmt" - ppool "github.com/consensys/linea-monorepo/prover/utils/parallel/pool" "math/big" "reflect" "runtime" "sync" + ppool "github.com/consensys/linea-monorepo/prover/utils/parallel/pool" + "github.com/consensys/gnark/frontend" "github.com/sirupsen/logrus" @@ -396,12 +397,12 @@ func (ctx mptsCtx) claimEvaluation(run *wizard.ProverRuntime) { } // verifier of the evaluation -func (ctx mptsCtx) verifier(run *wizard.VerifierRuntime) error { +func (ctx mptsCtx) verifier(run wizard.Runtime) error { ys, hs := ctx.getYsHs( run.GetUnivariateParams, func(qName ifaces.QueryID) query.UnivariateEval { - return run.Spec.QueriesParams.Data(qName).(query.UnivariateEval) + return run.GetQuery(qName).(query.UnivariateEval) }, ) @@ -513,7 +514,7 @@ func (ctx mptsCtx) verifier(run *wizard.VerifierRuntime) error { Gnark function generating constraints to mirror the verification of the evaluation step. */ -func (ctx mptsCtx) gnarkVerify(api frontend.API, c *wizard.WizardVerifierCircuit) { +func (ctx mptsCtx) gnarkVerify(api frontend.API, c wizard.GnarkRuntime) { logrus.Infof("Start verifying MPTS reduction") @@ -696,7 +697,7 @@ func getLagrangesPolys(domain []field.Element) (lagranges [][]field.Element) { // Mirrrors `getYsHs` to build a gnark circuit func (ctx mptsCtx) getYsHsGnark( - c *wizard.WizardVerifierCircuit, + c wizard.GnarkRuntime, ) ( ys map[ifaces.ColID][]frontend.Variable, hs []frontend.Variable, diff --git a/prover/protocol/compiler/univariates/naturalize.go b/prover/protocol/compiler/univariates/naturalize.go index 658591a8ee3..abc6e06e375 100644 --- a/prover/protocol/compiler/univariates/naturalize.go +++ b/prover/protocol/compiler/univariates/naturalize.go @@ -275,7 +275,7 @@ func (ctx *naturalizationCtx) prove(run *wizard.ProverRuntime) { } } -func (ctx naturalizationCtx) Verify(run *wizard.VerifierRuntime) error { +func (ctx naturalizationCtx) Verify(run wizard.Runtime) error { // Get the original query originalQuery := run.GetUnivariateEval(ctx.q.QueryID) @@ -346,7 +346,7 @@ func (ctx naturalizationCtx) Verify(run *wizard.VerifierRuntime) error { } -func (ctx naturalizationCtx) GnarkVerify(api frontend.API, c *wizard.WizardVerifierCircuit) { +func (ctx naturalizationCtx) GnarkVerify(api frontend.API, c wizard.GnarkRuntime) { logrus.Tracef("verifying naturalization") diff --git a/prover/protocol/compiler/vortex/compiler.go b/prover/protocol/compiler/vortex/compiler.go index 7aa626bd381..9f690d7cba8 100644 --- a/prover/protocol/compiler/vortex/compiler.go +++ b/prover/protocol/compiler/vortex/compiler.go @@ -119,13 +119,17 @@ type Ctx struct { VortexParams *vortex.Params SisParams *ringsis.Params // Optional parameter - numOpenedCol int + NumOpenedCol int // By rounds commitments : if a round is dried we make an empty sublist. // Inversely, for the `driedByRounds` which track the dried commitments. CommitmentsByRounds collection.VecVec[ifaces.ColID] DriedByRounds collection.VecVec[ifaces.ColID] + // RunStateNamePrefix is used to prefix some of the names of components of the + // compilation context. Mainly state objects. + RunStateNamePrefix string + // Items created by Vortex, includes the proof message and the coins Items struct { // List of items used only if the CommitPrecomputed flag is set @@ -135,17 +139,13 @@ type Ctx struct { PrecomputedColums []ifaces.Column // Merkle Root of the precomputeds columns MerkleRoot ifaces.Column - // List of the column hashes for the precomputed columns - Dh ifaces.Column // Committed matrix (rs encoded) of the precomputed columns CommittedMatrix vortex.EncodedMatrix // Tree in case of Merkle mode - tree *smt.Tree + Tree *smt.Tree // colHashes used in self recursion DhWithMerkle []field.Element } - // (not used in the Merkle proof version) - Dh []ifaces.Column // Alpha is a random combination linear coin Alpha coin.Info // Linear combination of the row-encoded matrix @@ -185,12 +185,10 @@ func newCtx(comp *wizard.CompiledIOP, univQ query.UnivariateEval, blowUpFactor i Precomputeds struct { PrecomputedColums []ifaces.Column MerkleRoot ifaces.Column - Dh ifaces.Column CommittedMatrix vortex.EncodedMatrix - tree *smt.Tree + Tree *smt.Tree DhWithMerkle []field.Element } - Dh []ifaces.Column Alpha coin.Info Ualpha ifaces.Column Q coin.Info @@ -437,8 +435,8 @@ func (ctx *Ctx) NbColsToOpen() int { // If the context was created with the relevant option, // we return the instructed value - if ctx.numOpenedCol > 0 { - return ctx.numOpenedCol + if ctx.NumOpenedCol > 0 { + return ctx.NumOpenedCol } if !utils.IsPowerOfTwo(ctx.BlowUpFactor) { @@ -530,7 +528,9 @@ func (ctx *Ctx) NumEncodedCols() int { return res } -// Create a method to decide when to commit to the precomputed +// IsCommitToPrecomputed returns true if the current compilation step +// commits to the precomputed columns. This is detected by checking if +// the number of precomputed columns is greater than the dry treshold. func (ctx *Ctx) IsCommitToPrecomputed() bool { return len(ctx.Items.Precomputeds.PrecomputedColums) > ctx.DryTreshold } @@ -710,7 +710,7 @@ func (ctx *Ctx) commitPrecomputeds() { committedMatrix, tree, colHashes := ctx.VortexParams.CommitMerkle(pols) ctx.Items.Precomputeds.DhWithMerkle = colHashes ctx.Items.Precomputeds.CommittedMatrix = committedMatrix - ctx.Items.Precomputeds.tree = tree + ctx.Items.Precomputeds.Tree = tree // And assign the 1-sized column to contain the root var root field.Element diff --git a/prover/protocol/compiler/vortex/gnark_verifier.go b/prover/protocol/compiler/vortex/gnark_verifier.go index 430a1560408..3fd7795ae8d 100644 --- a/prover/protocol/compiler/vortex/gnark_verifier.go +++ b/prover/protocol/compiler/vortex/gnark_verifier.go @@ -14,7 +14,7 @@ import ( "github.com/consensys/linea-monorepo/prover/utils" ) -func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) { +func (ctx *Ctx) GnarkVerify(api frontend.API, vr wizard.GnarkRuntime) { // The skip verification flag may be on, if the current vortex // context get self-recursed. In this case, the verifier does @@ -61,7 +61,7 @@ func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) // function that will defer the hashing to gkr factoryHasherFunc := func(_ frontend.API) (hash.FieldHasher, error) { - h := vr.HasherFactory.NewHasher() + h := vr.GetHasherFactory().NewHasher() return h, nil } @@ -91,7 +91,7 @@ func (ctx *Ctx) GnarkVerify(api frontend.API, vr *wizard.WizardVerifierCircuit) } // returns the Ys as a vector -func (ctx *Ctx) gnarkGetYs(_ frontend.API, vr *wizard.WizardVerifierCircuit) (ys [][]frontend.Variable) { +func (ctx *Ctx) gnarkGetYs(_ frontend.API, vr wizard.GnarkRuntime) (ys [][]frontend.Variable) { query := ctx.Query params := vr.GetUnivariateParams(ctx.Query.QueryID) @@ -161,7 +161,7 @@ func (ctx *Ctx) gnarkGetYs(_ frontend.API, vr *wizard.WizardVerifierCircuit) (ys // Returns the opened columns from the messages. The returned columns are // split "by-commitment-round". -func (ctx *Ctx) GnarkRecoverSelectedColumns(api frontend.API, vr *wizard.WizardVerifierCircuit) [][][]frontend.Variable { +func (ctx *Ctx) GnarkRecoverSelectedColumns(api frontend.API, vr wizard.GnarkRuntime) [][][]frontend.Variable { // Collect the columns : first extract the full columns // Bear in mind that the prover messages are zero-padded @@ -213,7 +213,7 @@ func (ctx *Ctx) GnarkRecoverSelectedColumns(api frontend.API, vr *wizard.WizardV } // Evaluates explicitly the public polynomials (proof, vk, public inputs) -func (ctx *Ctx) gnarkExplicitPublicEvaluation(api frontend.API, vr *wizard.WizardVerifierCircuit) { +func (ctx *Ctx) gnarkExplicitPublicEvaluation(api frontend.API, vr wizard.GnarkRuntime) { params := vr.GetUnivariateParams(ctx.Query.QueryID) diff --git a/prover/protocol/compiler/vortex/names.go b/prover/protocol/compiler/vortex/names.go index 7536bcf2970..ced90f744b7 100644 --- a/prover/protocol/compiler/vortex/names.go +++ b/prover/protocol/compiler/vortex/names.go @@ -32,14 +32,32 @@ func (ctx *Ctx) CommitmentName(round int) ifaces.ColID { return ifaces.ColIDf("VORTEX_%v_COMMITMENT_ROUND_%v", ctx.SelfRecursionCount, round) } +// SisHashName returns a preformatted message representing the Sis hash digests +// for each round that we store in the state. +func (ctx *Ctx) SisHashName(round int) string { + name := fmt.Sprintf("VORTEX_%v_SIS_HASH_%v", ctx.SelfRecursionCount, round) + if len(ctx.RunStateNamePrefix) == 0 { + return name + } + return ctx.RunStateNamePrefix + "." + name +} + // returns the name of a prover state for a given round of Vortex func (ctx *Ctx) VortexProverStateName(round int) string { - return fmt.Sprintf("VORTEX_%v_PROVER_STATE_%v", ctx.SelfRecursionCount, round) + name := fmt.Sprintf("VORTEX_%v_PROVER_STATE_%v", ctx.SelfRecursionCount, round) + if len(ctx.RunStateNamePrefix) == 0 { + return name + } + return ctx.RunStateNamePrefix + "." + name } // returns the name of a prover state for a given round of Vortex func (ctx *Ctx) MerkleTreeName(round int) string { - return fmt.Sprintf("VORTEX_%v_MERKLE_TREE_%v", ctx.SelfRecursionCount, round) + name := fmt.Sprintf("VORTEX_%v_MERKLE_TREE_%v", ctx.SelfRecursionCount, round) + if len(ctx.RunStateNamePrefix) == 0 { + return name + } + return ctx.RunStateNamePrefix + "." + name } // returns the name of the vector containing all the Merkle proofs diff --git a/prover/protocol/compiler/vortex/option.go b/prover/protocol/compiler/vortex/option.go index 668e8ec758e..be66d3cbd3a 100644 --- a/prover/protocol/compiler/vortex/option.go +++ b/prover/protocol/compiler/vortex/option.go @@ -11,7 +11,7 @@ type VortexOp func(ctx *Ctx) // not be used in production) func ForceNumOpenedColumns(nbCol int) VortexOp { return func(ctx *Ctx) { - ctx.numOpenedCol = nbCol + ctx.NumOpenedCol = nbCol } } @@ -36,3 +36,12 @@ func ReplaceSisByMimc() VortexOp { ctx.SisParams = nil } } + +// PremarkAsSelfRecursed marks the ctx as selfrecursed. This is useful +// toward conglomerating the receiver comp but is not needed for +// self-recursion or full-recursion. +func PremarkAsSelfRecursed() VortexOp { + return func(ctx *Ctx) { + ctx.IsSelfrecursed = true + } +} diff --git a/prover/protocol/compiler/vortex/prover.go b/prover/protocol/compiler/vortex/prover.go index bad0c99a18d..4b81787d670 100644 --- a/prover/protocol/compiler/vortex/prover.go +++ b/prover/protocol/compiler/vortex/prover.go @@ -33,7 +33,7 @@ func (ctx *Ctx) AssignColumn(round int) func(*wizard.ProverRuntime) { // Only to be read by the self-recursion compiler. if ctx.IsSelfrecursed { - pr.State.InsertNew(string(ctx.CommitmentName(round)), sisDigest) + pr.State.InsertNew(ctx.SisHashName(round), sisDigest) } // And assign the 1-sized column to contain the root @@ -67,11 +67,44 @@ func (ctx *Ctx) ComputeLinearComb(pr *wizard.ProverRuntime) { } // And get the randomness - randomCoinLC := pr.GetRandomCoinField(ctx.LinCombRandCoinName()) + randomCoinLC := pr.GetRandomCoinField(ctx.Items.Alpha.Name) // and compute and assign the random linear combination of the rows proof := ctx.VortexParams.InitOpeningWithLC(committedSV, randomCoinLC) - pr.AssignColumn(ctx.LinCombName(), proof.LinearCombination) + pr.AssignColumn(ctx.Items.Ualpha.GetColID(), proof.LinearCombination) +} + +// ComputeLinearCombFromRsMatrix is the same as ComputeLinearComb but uses +// the RS encoded matrix instead of using the basic one. It is slower than +// the later but is recommended. +func (ctx *Ctx) ComputeLinearCombFromRsMatrix(pr *wizard.ProverRuntime) { + + committedSV := []smartvectors.SmartVector{} + + // Add the precomputed columns to commitedSV if IsCommitToPrecomputed is true + if ctx.IsCommitToPrecomputed() { + committedSV = append(committedSV, ctx.Items.Precomputeds.CommittedMatrix...) + } + + // Collect all the committed polynomials : round by round + for round := 0; round <= ctx.MaxCommittedRound; round++ { + // There are not included in the commitments so there + // is no need to compute their linear combination. + if ctx.isDry(round) { + continue + } + + committedMatrix := pr.State.MustGet(ctx.VortexProverStateName(round)).(vortex.EncodedMatrix) + committedSV = append(committedSV, committedMatrix...) + } + + // And get the randomness + randomCoinLC := pr.GetRandomCoinField(ctx.Items.Alpha.Name) + + // and compute and assign the random linear combination of the rows + proof := ctx.VortexParams.InitOpeningFromAlreadyEncodedLC(committedSV, randomCoinLC) + + pr.AssignColumn(ctx.Items.Ualpha.GetColID(), proof.LinearCombination) } // Prover steps of Vortex where he opens the columns selected by the verifier @@ -85,7 +118,7 @@ func (ctx *Ctx) OpenSelectedColumns(pr *wizard.ProverRuntime) { // Append the precomputed committedMatrices and trees when IsCommitToPrecomputed is true if ctx.IsCommitToPrecomputed() { committedMatrices = append(committedMatrices, ctx.Items.Precomputeds.CommittedMatrix) - trees = append(trees, ctx.Items.Precomputeds.tree) + trees = append(trees, ctx.Items.Precomputeds.Tree) } for round := 0; round <= ctx.MaxCommittedRound; round++ { @@ -107,7 +140,7 @@ func (ctx *Ctx) OpenSelectedColumns(pr *wizard.ProverRuntime) { trees = append(trees, tree) } - entryList := pr.GetRandomCoinIntegerVec(ctx.RandColSelectionName()) + entryList := pr.GetRandomCoinIntegerVec(ctx.Items.Q.Name) proof := vortex.OpeningProof{} // Merkle mode only: @@ -131,11 +164,11 @@ func (ctx *Ctx) OpenSelectedColumns(pr *wizard.ProverRuntime) { assignable = smartvectors.RightZeroPadded(fullCol, utils.NextPowerOfTwo(len(fullCol))) } - pr.AssignColumn(ctx.SelectedColName(j), assignable) + pr.AssignColumn(ctx.Items.OpenedColumns[j].GetColID(), assignable) } packedMProofs := ctx.packMerkleProofs(proof.MerkleProofs) - pr.AssignColumn(ctx.MerkleProofName(), packedMProofs) + pr.AssignColumn(ctx.Items.MerkleProofs.GetColID(), packedMProofs) } // returns true if the round is dry (i.e, there is nothing to commit to) diff --git a/prover/protocol/compiler/vortex/verifier.go b/prover/protocol/compiler/vortex/verifier.go index 1b27c1779b6..e27e7d48f75 100644 --- a/prover/protocol/compiler/vortex/verifier.go +++ b/prover/protocol/compiler/vortex/verifier.go @@ -13,7 +13,7 @@ import ( "github.com/consensys/linea-monorepo/prover/utils/types" ) -func (ctx *Ctx) Verify(vr *wizard.VerifierRuntime) error { +func (ctx *Ctx) Verify(vr wizard.Runtime) error { // The skip verification flag may be on, if the current vortex // context get self-recursed. In this case, the verifier does @@ -77,7 +77,7 @@ func (ctx *Ctx) getNbCommittedRows(round int) int { } // returns the Ys as a vector -func (ctx *Ctx) getYs(vr *wizard.VerifierRuntime) (ys [][]field.Element) { +func (ctx *Ctx) getYs(vr wizard.Runtime) (ys [][]field.Element) { query := ctx.Query params := vr.GetUnivariateParams(ctx.Query.QueryID) @@ -130,7 +130,7 @@ func (ctx *Ctx) getYs(vr *wizard.VerifierRuntime) (ys [][]field.Element) { // Returns the opened columns from the messages. The returned columns are // split "by-commitment-round". -func (ctx *Ctx) RecoverSelectedColumns(vr *wizard.VerifierRuntime, entryList []int) [][][]field.Element { +func (ctx *Ctx) RecoverSelectedColumns(vr wizard.Runtime, entryList []int) [][][]field.Element { // Collect the columns : first extract the full columns // Bear in mind that the prover messages are zero-padded @@ -184,7 +184,7 @@ func (ctx *Ctx) RecoverSelectedColumns(vr *wizard.VerifierRuntime, entryList []i } // Evaluates explicitly the public polynomials (proof, vk, public inputs) -func (ctx *Ctx) explicitPublicEvaluation(vr *wizard.VerifierRuntime) error { +func (ctx *Ctx) explicitPublicEvaluation(vr wizard.Runtime) error { params := vr.GetUnivariateParams(ctx.Query.QueryID) diff --git a/prover/protocol/dedicated/functionals/fold.go b/prover/protocol/dedicated/functionals/fold.go index a78df5436bd..a7b572a4f44 100644 --- a/prover/protocol/dedicated/functionals/fold.go +++ b/prover/protocol/dedicated/functionals/fold.go @@ -50,12 +50,12 @@ func Fold(comp *wizard.CompiledIOP, h ifaces.Column, x ifaces.Accessor, innerDeg verRound := utils.Max(outerCoinAcc.Round(), foldedEvalAcc.Round()) // Check that the two evaluations yield the same result - comp.InsertVerifier(verRound, func(a *wizard.VerifierRuntime) error { + comp.InsertVerifier(verRound, func(a wizard.Runtime) error { if foldedEvalAcc.GetVal(a) != hEvalAcc.GetVal(a) { return fmt.Errorf("verifier of folding failed %v", foldedName) } return nil - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { c := foldedEvalAcc.GetFrontendVariable(api, wvc) c_ := hEvalAcc.GetFrontendVariable(api, wvc) api.AssertIsEqual(c, c_) diff --git a/prover/protocol/dedicated/functionals/foldouter.go b/prover/protocol/dedicated/functionals/foldouter.go index 98340869ebc..c4fcaa53314 100644 --- a/prover/protocol/dedicated/functionals/foldouter.go +++ b/prover/protocol/dedicated/functionals/foldouter.go @@ -53,12 +53,12 @@ func FoldOuter(comp *wizard.CompiledIOP, h ifaces.Column, x ifaces.Accessor, out verRound := utils.Max(innerCoinAcc.Round(), foldedEvalAcc.Round()) // Check that the two evaluations yield the same result - comp.InsertVerifier(verRound, func(run *wizard.VerifierRuntime) error { + comp.InsertVerifier(verRound, func(run wizard.Runtime) error { if foldedEvalAcc.GetVal(run) != hEvalAcc.GetVal(run) { return fmt.Errorf("verifier of folding failed %v", foldedName) } return nil - }, func(api frontend.API, run *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, run wizard.GnarkRuntime) { c := foldedEvalAcc.GetFrontendVariable(api, run) c_ := hEvalAcc.GetFrontendVariable(api, run) api.AssertIsEqual(c, c_) diff --git a/prover/protocol/dedicated/mimc/linear_hash.go b/prover/protocol/dedicated/mimc/linear_hash.go index 24b01abd3b5..3d1d75cfd8b 100644 --- a/prover/protocol/dedicated/mimc/linear_hash.go +++ b/prover/protocol/dedicated/mimc/linear_hash.go @@ -24,7 +24,7 @@ type linearHashCtx struct { // The compiled IOP comp *wizard.CompiledIOP - // Names of the "data" columns + name string // Output column, which containing the result each // individual hash. @@ -60,6 +60,7 @@ Check a linear hashby chunk of columns */ func CheckLinearHash( comp *wizard.CompiledIOP, + name string, tohash ifaces.Column, period int, numHash int, expectedHashes ifaces.Column, @@ -68,6 +69,7 @@ func CheckLinearHash( // Initialize the context ctx := linearHashCtx{ comp: comp, + name: name, ToHash: tohash, Period: period, NumHash: numHash, @@ -83,7 +85,7 @@ func CheckLinearHash( if ctx.IsFullyActive { selector.CheckSubsample( comp, - prefixWithLinearHash(comp, "RES_EXTRACTION"), + prefixWithLinearHash(comp, name, "RES_EXTRACTION"), []ifaces.Column{ctx.NewStateClean}, []ifaces.Column{ctx.ExpectedHash}, period-1, @@ -91,7 +93,7 @@ func CheckLinearHash( } else { ctx.comp.InsertInclusion( ctx.Round, - ifaces.QueryID(prefixWithLinearHash(comp, "RESULT_CHECK_%v", tohash.GetColID())), + ifaces.QueryID(prefixWithLinearHash(comp, name, "RESULT_CHECK_%v", tohash.GetColID())), []ifaces.Column{ctx.IsEndOfHash, ctx.NewStateClean}, []ifaces.Column{ctx.IsActiveExpected(), ctx.ExpectedHash}, ) @@ -99,9 +101,9 @@ func CheckLinearHash( } -func prefixWithLinearHash(comp *wizard.CompiledIOP, msg string, args ...any) string { - args = append([]any{comp.SelfRecursionCount}, args...) - return fmt.Sprintf("LINEAR_HASH_%v_"+msg, args...) +func prefixWithLinearHash(comp *wizard.CompiledIOP, name, msg string, args ...any) string { + args = append([]any{name, comp.SelfRecursionCount}, args...) + return fmt.Sprintf("%v.LINEAR_HASH_%v_"+msg, args...) } // Declares assign and constraints the columns OldStates and NewStates @@ -110,19 +112,19 @@ func (ctx *linearHashCtx) HashingCols() { // Registers the old states columns ctx.OldState = ctx.comp.InsertCommit( ctx.Round, - ifaces.ColID(prefixWithLinearHash(ctx.comp, "OLD_STATE_%v", ctx.ToHash.GetColID())), + ifaces.ColID(prefixWithLinearHash(ctx.comp, ctx.name, "OLD_STATE_%v", ctx.ToHash.GetColID())), ctx.ToHash.Size(), ) ctx.NewState = ctx.comp.InsertCommit( ctx.Round, - ifaces.ColID(prefixWithLinearHash(ctx.comp, "NEW_STATE_%v", ctx.ToHash.GetColID())), + ifaces.ColID(prefixWithLinearHash(ctx.comp, ctx.name, "NEW_STATE_%v", ctx.ToHash.GetColID())), ctx.ToHash.Size(), ) ctx.NewStateClean = ctx.comp.InsertCommit( ctx.Round, - ifaces.ColIDf(prefixWithLinearHash(ctx.comp, "NEW_STATE_CLEAN_%v", ctx.ToHash.GetColID())), + ifaces.ColIDf(prefixWithLinearHash(ctx.comp, ctx.name, "NEW_STATE_CLEAN_%v", ctx.ToHash.GetColID())), ctx.ToHash.Size(), ) @@ -179,7 +181,7 @@ func (ctx *linearHashCtx) HashingCols() { ctx.comp.InsertGlobal( ctx.Round, - ifaces.QueryID(prefixWithLinearHash(ctx.comp, "STATE_PROPAGATION_%v", ctx.ToHash.GetColID())), + ifaces.QueryID(prefixWithLinearHash(ctx.comp, ctx.name, "STATE_PROPAGATION_%v", ctx.ToHash.GetColID())), expr, true, // no bound cancel to also enforce the first value of old state to be zero ) @@ -189,7 +191,7 @@ func (ctx *linearHashCtx) HashingCols() { // ctx.comp.InsertGlobal( ctx.Round, - ifaces.QueryIDf(prefixWithLinearHash(ctx.comp, "CLEAN_NEW_STATE_%v", ctx.ToHash.GetColID())), + ifaces.QueryIDf(prefixWithLinearHash(ctx.comp, ctx.name, "CLEAN_NEW_STATE_%v", ctx.ToHash.GetColID())), ctx.IsActiveVar(). Mul(ifaces.ColumnAsVariable(ctx.NewState)). Sub(ifaces.ColumnAsVariable(ctx.NewStateClean)), @@ -200,7 +202,7 @@ func (ctx *linearHashCtx) HashingCols() { // ctx.comp.InsertMiMC( ctx.Round, - ifaces.QueryID(prefixWithLinearHash(ctx.comp, "BLOCKS_COMPRESSION_%v", ctx.ToHash.GetColID())), + ifaces.QueryID(prefixWithLinearHash(ctx.comp, ctx.name, "BLOCKS_COMPRESSION_%v", ctx.ToHash.GetColID())), ctx.ToHash, ctx.OldState, ctx.NewState, ) @@ -227,7 +229,7 @@ func (ctx *linearHashCtx) IsActiveExpected() ifaces.Column { } ctx.isActiveExpected = ctx.comp.InsertPrecomputed( - ifaces.ColIDf(prefixWithLinearHash(ctx.comp, "IS_ACTIVE_EXPECTED_%v", ctx.ToHash.GetColID())), + ifaces.ColIDf(prefixWithLinearHash(ctx.comp, ctx.name, "IS_ACTIVE_EXPECTED_%v", ctx.ToHash.GetColID())), assignment, ) } @@ -244,7 +246,7 @@ func (ctx *linearHashCtx) IsActiveVar() *symbolic.Expression { // Lazily registers the columns if ctx.IsActiveLarge == nil { ctx.IsActiveLarge = ctx.comp.InsertPrecomputed( - ifaces.ColIDf(prefixWithLinearHash(ctx.comp, "IS_ACTIVE_%v", ctx.ToHash.GetColID())), + ifaces.ColIDf(prefixWithLinearHash(ctx.comp, ctx.name, "IS_ACTIVE_%v", ctx.ToHash.GetColID())), smartvectors.RightZeroPadded( vector.Repeat(field.One(), ctx.NumHash*ctx.Period), ctx.ToHash.Size(), @@ -272,7 +274,7 @@ func (ctx *linearHashCtx) IsEndOfHashVar() *symbolic.Expression { } ctx.IsEndOfHash = ctx.comp.InsertPrecomputed( - ifaces.ColIDf(prefixWithLinearHash(ctx.comp, "IS_END_OF_HASH_%v", ctx.ToHash.GetColID())), + ifaces.ColIDf(prefixWithLinearHash(ctx.comp, ctx.name, "IS_END_OF_HASH_%v", ctx.ToHash.GetColID())), smartvectors.RightZeroPadded(window, ctx.ToHash.Size()), ) } diff --git a/prover/protocol/dedicated/mimc/linear_hash_test.go b/prover/protocol/dedicated/mimc/linear_hash_test.go index 95bb2dc8124..47ab9b31103 100644 --- a/prover/protocol/dedicated/mimc/linear_hash_test.go +++ b/prover/protocol/dedicated/mimc/linear_hash_test.go @@ -38,7 +38,7 @@ func TestLinearHash(t *testing.T) { define := func(b *wizard.Builder) { tohash = b.RegisterCommit("TOHASH", numRowLarge) expectedhash = b.RegisterCommit("HASHED", numRowSmall) - linhash.CheckLinearHash(b.CompiledIOP, tohash, period, numhash, expectedhash) + linhash.CheckLinearHash(b.CompiledIOP, "test", tohash, period, numhash, expectedhash) } prove := func(run *wizard.ProverRuntime) { diff --git a/prover/protocol/dedicated/plonk/alignment.go b/prover/protocol/dedicated/plonk/alignment.go index 2ac362c49c6..5d5cd743761 100644 --- a/prover/protocol/dedicated/plonk/alignment.go +++ b/prover/protocol/dedicated/plonk/alignment.go @@ -390,7 +390,7 @@ type checkActivatorAndMask struct { skipped bool } -func (c *checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error { +func (c *checkActivatorAndMask) Run(run wizard.Runtime) error { for i := range c.circMaskOpenings { var ( localOpening = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID) @@ -409,7 +409,7 @@ func (c *checkActivatorAndMask) Run(run *wizard.VerifierRuntime) error { return nil } -func (c *checkActivatorAndMask) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (c *checkActivatorAndMask) RunGnark(api frontend.API, run wizard.GnarkRuntime) { for i := range c.circMaskOpenings { var ( valOpened = run.GetLocalPointEvalParams(c.circMaskOpenings[i].ID).Y diff --git a/prover/protocol/dedicated/plonk/compile.go b/prover/protocol/dedicated/plonk/compile.go index 3e9182fbe64..9b429781a6f 100644 --- a/prover/protocol/dedicated/plonk/compile.go +++ b/prover/protocol/dedicated/plonk/compile.go @@ -306,7 +306,7 @@ type checkingActivators struct { var _ wizard.VerifierAction = &checkingActivators{} -func (ca *checkingActivators) Run(run *wizard.VerifierRuntime) error { +func (ca *checkingActivators) Run(run wizard.Runtime) error { for i := range ca.Cols { curr := ca.Cols[i].GetColAssignmentAt(run, 0) @@ -325,7 +325,7 @@ func (ca *checkingActivators) Run(run *wizard.VerifierRuntime) error { return nil } -func (ca *checkingActivators) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { +func (ca *checkingActivators) RunGnark(api frontend.API, run wizard.GnarkRuntime) { for i := range ca.Cols { curr := ca.Cols[i].GetColAssignmentGnarkAt(run, 0) diff --git a/prover/protocol/dedicated/reedsolomon/reedsolomon.go b/prover/protocol/dedicated/reedsolomon/reedsolomon.go index 02f656305f5..e6753c655ae 100644 --- a/prover/protocol/dedicated/reedsolomon/reedsolomon.go +++ b/prover/protocol/dedicated/reedsolomon/reedsolomon.go @@ -60,14 +60,14 @@ func CheckReedSolomon(comp *wizard.CompiledIOP, rate int, h ifaces.Column) { h, ) - comp.InsertVerifier(round+1, func(a *wizard.VerifierRuntime) error { + comp.InsertVerifier(round+1, func(a wizard.Runtime) error { y := coeffCheck.GetVal(a) y_ := evalCheck.GetVal(a) if y != y_ { return fmt.Errorf("reed-solomon check failed - %v is not a codeword", h.GetColID()) } return nil - }, func(api frontend.API, wvc *wizard.WizardVerifierCircuit) { + }, func(api frontend.API, wvc wizard.GnarkRuntime) { y := coeffCheck.GetFrontendVariable(api, wvc) y_ := evalCheck.GetFrontendVariable(api, wvc) api.AssertIsEqual(y, y_) diff --git a/prover/protocol/dedicated/selector/subsample.go b/prover/protocol/dedicated/selector/subsample.go index 2b7dbecad31..1ac6f0bcb10 100644 --- a/prover/protocol/dedicated/selector/subsample.go +++ b/prover/protocol/dedicated/selector/subsample.go @@ -233,7 +233,7 @@ func CheckSubsample(comp *wizard.CompiledIOP, name string, large, small []ifaces comp.InsertVerifier( round+1, - func(run *wizard.VerifierRuntime) error { + func(run wizard.Runtime) error { resAccLast := run.GetLocalPointEvalParams(accLargeLast.ID) expectedResAccLast := run.GetLocalPointEvalParams(accSmallLast.ID) if resAccLast.Y != expectedResAccLast.Y { @@ -241,7 +241,7 @@ func CheckSubsample(comp *wizard.CompiledIOP, name string, large, small []ifaces } return nil }, - func(a frontend.API, run *wizard.WizardVerifierCircuit) { + func(a frontend.API, run wizard.GnarkRuntime) { resAccLast := run.GetLocalPointEvalParams(accLargeLast.ID) expectedResAccLast := run.GetLocalPointEvalParams(accSmallLast.ID) a.AssertIsEqual(resAccLast.Y, expectedResAccLast.Y) diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index d14f9f5033e..855e25b2790 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/accessors" "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -162,7 +163,7 @@ func GetShareOfLogDerivativeSum(in DistributionInputs) { // declare [query.LogDerivSumParams] as [wizard.PublicInput] moduleComp.PublicInputs = append(moduleComp.PublicInputs, wizard.PublicInput{ - Name: accessors.LOGDERIVSUM_ACCESSOR, + Name: constants.LogDerivativeSumPublicInput, Acc: accessors.NewLogDerivSumAccessor(logDerivQuery), }) diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go index 654d5589d30..e62939f5892 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go @@ -1,16 +1,18 @@ package inclusion_test import ( + "errors" "testing" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" logderiv "github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum" "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/inclusion" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" "github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp" md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" - "github.com/consensys/linea-monorepo/prover/protocol/distributed/xcomp" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/stretchr/testify/require" @@ -25,7 +27,7 @@ func TestSeedGeneration(t *testing.T) { ) var ( - allVerfiers = []*wizard.VerifierRuntime{} + allVerfiers = []wizard.Runtime{} ) //initialComp define := func(b *wizard.Builder) { @@ -176,7 +178,20 @@ func TestSeedGeneration(t *testing.T) { } // apply the crosse checks over the public inputs. - xComp := xcomp.GetCrossComp(allVerfiers) - wizard.Verify(xComp, wizard.Proof{}) + require.NoError(t, checkConsistency(allVerfiers)) +} + +func checkConsistency(runs []wizard.Runtime) error { + + var res field.Element + for _, run := range runs { + logderiv := run.GetPublicInput(constants.LogDerivativeSumPublicInput) + res.Add(&res, &logderiv) + } + + if !res.IsZero() { + return errors.New("the logderiv sums do not cancel each others") + } + return nil } diff --git a/prover/protocol/distributed/conglomeration/conglomeration.go b/prover/protocol/distributed/conglomeration/conglomeration.go new file mode 100644 index 00000000000..02a016a8ac3 --- /dev/null +++ b/prover/protocol/distributed/conglomeration/conglomeration.go @@ -0,0 +1,303 @@ +package conglomeration + +import ( + "fmt" + + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/selfrecursion" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// recursionCtx holds compilation context informations about the wizard +type recursionCtx struct { + // A pointer to the compiled-IOP over which the compilation step has run + Translator *compTranslator + Tmpl *wizard.CompiledIOP + // The Vortex compilation context + PcsCtx *vortex.Ctx + PublicInputs []wizard.PublicInput + NonEmptyMerkleRootPositions []int + FirstRound, LastRound int + Columns [][]ifaces.Column + // The columns ignored are the one that are compiled by the vortex context. + // They are added in the target 'comp' but are assigned to zero. Although, + // they do not directly play a role in the protocol anymore, they are still + // referenced by the self-recursion compiler. + ColumnIgnoredPrecomputed []ifaces.Column + ColumnsIgnored [][]ifaces.Column + QueryParams [][]ifaces.Query + VerifierActions [][]wizard.VerifierAction + Coins [][]coin.Info + FsHooks [][]wizard.VerifierAction + LocalOpenings []query.LocalOpening +} + +// ConglomerateDefineFunc returns a function that defines a conglomerate +// comp and a placeholder pointer for the recursion context. On return +// of the function, the pointer points to an empty slice and is populated +// once [wizard.Compile] has been called with def. +// +// To be conglomerable, `tmpl` must be compiled with the [vortex.Compile] +// using the [vortex.PremarkAsSelfRecursed] option and without the +// [vortex.ReplaceByMiMC]. +func ConglomerateDefineFunc(tmpl *wizard.CompiledIOP, maxNumSegment int) (def func(*wizard.Builder), ctxsPlaceHolder *[]*recursionCtx) { + + var ctxs []*recursionCtx + def = func(b *wizard.Builder) { + + comp := b.CompiledIOP + + for id := 0; id < maxNumSegment; id++ { + prefix := fmt.Sprintf("verifier-%v", id) + ctx := initRecursionCtx(prefix, comp, tmpl) + ctx.captureCompPreVortex(tmpl) + ctx.captureVortexCtx(tmpl) + ctxs = append(ctxs, ctx) + } + + // This FS hook has to be defined before we add the pre-vortex verifier + // hooks to ensure that the FS state is properly initialize the verifier + // runtime. + comp.FiatShamirHooks.AppendToInner(0, &SubFsInitialize{Ctxs: ctxs}) + + for round := 0; round <= ctxs[0].LastRound; round++ { + + var ( + hasCoin = len(ctxs[0].Coins[round]) > 0 + hasVAction = len(ctxs[0].VerifierActions[round]) > 0 + hasFsHook = len(ctxs[0].FsHooks[round]) > 0 + hasColumn = len(ctxs[0].Columns[round]) > 0 + hasQParams = len(ctxs[0].QueryParams[round]) > 0 + ) + + if hasCoin || hasVAction || hasFsHook { + // The way the verifier runtime is that it will generate all the random coins at once and + // then, it runs all the verifier actions in parallel. What this action from the verifier + // is trying to do is to prepare a ctx-local FS state that can be later used in a join to + // derive a sound global FS state. Thus, we need it to run along side the "main" fs random + // coin generation. This is why this is declared as an FS hook and not as a VerifierAction. + comp.FiatShamirHooks.AppendToInner(round, &PreVortexVerifierStep{Ctxs: ctxs, Round: round}) + } + + if hasColumn || hasQParams { + comp.RegisterProverAction(round, &PreVortexProverStep{Ctxs: ctxs, Round: round}) + } + } + + comp.FiatShamirHooks.AppendToInner(ctxs[0].LastRound, &FsJoinHook{Ctxs: ctxs}) + comp.RegisterProverAction(ctxs[0].LastRound, &FsJoinProverStep{Ctxs: ctxs}) + comp.RegisterProverAction(ctxs[0].LastRound, &AssignVortexQuery{Ctxs: ctxs}) + comp.RegisterProverAction(ctxs[0].LastRound+1, &AssignVortexUAlpha{Ctxs: ctxs}) + comp.RegisterProverAction(ctxs[0].LastRound+2, &AssignVortexOpenedCols{Ctxs: ctxs}) + + // Importantly, the recursion compilation should happen after we added the vortex + // columns as they depends on the later. + for _, ctx := range ctxs { + selfrecursion.RecurseOverCustomCtx(comp, ctx.PcsCtx, ctx.Translator.Prefix) + } + + comp.RegisterVerifierAction(ctxs[0].LastRound, &CrossSegmentCheck{Ctxs: ctxs}) + } + + return def, &ctxs +} + +// initRecursionCtx initializes a new context +func initRecursionCtx(id string, target *wizard.CompiledIOP, tmpl *wizard.CompiledIOP) *recursionCtx { + return &recursionCtx{ + Translator: &compTranslator{Prefix: id, Target: target}, + Tmpl: tmpl, + } +} + +// captureCompPreVortex scans the content of tmpl to store the compilation infos of the +// CompiledIOP at the beginning of the compilation. The scanned wizard items are +// inserted into `comp` with a prefix `id` and recorded within the context. +func (ctx *recursionCtx) captureCompPreVortex(tmpl *wizard.CompiledIOP) { + + var ( + polyQuery = tmpl.PcsCtxs.(*vortex.Ctx).Query + lastRound = tmpl.QueriesParams.Round(polyQuery.QueryID) + + // This sanity-check ensures that the template has the right public inputs + _ = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput) + _ = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput) + _ = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput) + ) + + ctx.LastRound = lastRound + + for round := 0; round <= lastRound; round++ { + + ctx.Columns = append(ctx.Columns, []ifaces.Column{}) + ctx.QueryParams = append(ctx.QueryParams, []ifaces.Query{}) + ctx.VerifierActions = append(ctx.VerifierActions, []wizard.VerifierAction{}) + ctx.Coins = append(ctx.Coins, []coin.Info{}) + ctx.FsHooks = append(ctx.FsHooks, []wizard.VerifierAction{}) + + // Importantly, the coins are added before. Otherwise the 'assertConsistentRound' + // clause would not accept inserting columns or queries. + for _, cName := range tmpl.Coins.AllKeysAt(round) { + + if tmpl.Coins.IsSkippedFromVerifierTranscript(cName) { + continue + } + + coinInfo := tmpl.Coins.Data(cName) + coinInfo = ctx.Translator.InsertCoin(coinInfo) + ctx.Coins[round] = append(ctx.Coins[round], coinInfo) + ctx.Translator.Target.Coins.MarkAsSkippedFromVerifierTranscript(coinInfo.Name) + } + + for _, colName := range tmpl.Columns.AllKeysAt(round) { + + // filter the columns by status + var ( + col = tmpl.Columns.GetHandle(colName).(column.Natural) + status = col.Status() + ) + + if !status.IsPublic() { + // the column is not public so it is not part of the proof + continue + } + + var newCol ifaces.Column + + if tmpl.Precomputed.Exists(colName) { + newCol = ctx.Translator.InsertPrecomputed(col, tmpl.Precomputed.MustGet(colName)) + } else { + newCol = ctx.Translator.InsertColumn(col) + ctx.Columns[round] = append(ctx.Columns[round], newCol) + } + + ctx.Translator.Target.Columns.ExcludeFromProverFS(newCol.GetColID()) + } + + for _, qName := range tmpl.QueriesParams.AllKeysAt(round) { + + if tmpl.QueriesParams.IsSkippedFromVerifierTranscript(qName) { + continue + } + + // Importantly, the queries that we port should be already + // compiled in the tmpl. + if !tmpl.QueriesParams.IsIgnored(qName) { + panic("the template is invalid, all its queries should be compiled") + } + + // The uni-eval query is directly handled in a different section + // of the compilation. + if qName == polyQuery.QueryID { + continue + } + + // Note that we do not filter the already compiled queries + qInfo := tmpl.QueriesParams.Data(qName) + qInfo = ctx.Translator.InsertQueryParams(round, qInfo) + ctx.QueryParams[round] = append(ctx.QueryParams[round], qInfo) + ctx.Translator.Target.QueriesParams.MarkAsSkippedFromProverTranscript(qInfo.Name()) + } + + verifierActions := tmpl.SubVerifiers.Inner() + + for i := range verifierActions[round] { + + va := verifierActions[round][i] + if va.IsSkipped() { + continue + } + + ctx.VerifierActions[round] = append(ctx.VerifierActions[round], va) + } + + resetFs := tmpl.FiatShamirHooks.Inner() + + for _, fsHook := range resetFs[round] { + + if fsHook.IsSkipped() { + continue + } + + ctx.FsHooks[round] = append(ctx.VerifierActions[round], fsHook) + } + } +} + +func (ctx *recursionCtx) captureVortexCtx(tmpl *wizard.CompiledIOP) { + + var ( + srcVortexCtx = tmpl.PcsCtxs.(*vortex.Ctx) + comsByRound = srcVortexCtx.CommitmentsByRounds.Inner() + srcPrecomputed = srcVortexCtx.Items.Precomputeds.PrecomputedColums + ) + + ctx.ColumnIgnoredPrecomputed = make([]ifaces.Column, len(srcPrecomputed)) + for i := range srcPrecomputed { + ctx.ColumnIgnoredPrecomputed[i] = ctx.Translator.InsertPrecomputed( + srcPrecomputed[i].(column.Natural), + tmpl.Precomputed.MustGet(srcPrecomputed[i].GetColID()), + ) + } + + for _, coms := range comsByRound { + ctx.ColumnsIgnored = append(ctx.ColumnsIgnored, nil) + for _, comID := range coms { + com := tmpl.Columns.GetHandle(comID) + com = ctx.Translator.InsertColumn(com.(column.Natural)) + ctx.ColumnsIgnored[len(ctx.ColumnsIgnored)-1] = append(ctx.ColumnsIgnored[len(ctx.ColumnsIgnored)-1], com) + } + } + + if !srcVortexCtx.IsSelfrecursed || srcVortexCtx.ReplaceSisByMimc { + utils.Panic("the input vortex ctx is expected to be selfrecursed or having SIS replaced by MiMC. Please sure the input comp has been last compiled by Vortex with the option [vortex.MarkAsSelfRecursed]") + } + + dstVortexCtx := &vortex.Ctx{ + RunStateNamePrefix: ctx.Translator.Prefix, + BlowUpFactor: srcVortexCtx.BlowUpFactor, + DryTreshold: srcVortexCtx.DryTreshold, + CommittedRowsCount: srcVortexCtx.CommittedRowsCount, + NumCols: srcVortexCtx.NumCols, + MaxCommittedRound: srcVortexCtx.MaxCommittedRound, + NumOpenedCol: srcVortexCtx.NumOpenedCol, + VortexParams: srcVortexCtx.VortexParams, + SisParams: srcVortexCtx.SisParams, + // Although the srcVor + IsSelfrecursed: true, + CommitmentsByRounds: ctx.Translator.TranslateColumnVecVec(srcVortexCtx.CommitmentsByRounds), + DriedByRounds: ctx.Translator.TranslateColumnVecVec(srcVortexCtx.DriedByRounds), + PolynomialsTouchedByTheQuery: ctx.Translator.TranslateColumnSet(srcVortexCtx.PolynomialsTouchedByTheQuery), + ShadowCols: ctx.Translator.TranslateColumnSet(srcVortexCtx.ShadowCols), + Query: ctx.Translator.TranslateUniEval(ctx.LastRound, srcVortexCtx.Query), + } + + if srcVortexCtx.ReplaceSisByMimc { + panic("it should not replace by MiMC") + } + + ctx.Translator.Target.QueriesParams.MarkAsIgnored(dstVortexCtx.Query.QueryID) + + if srcVortexCtx.IsCommitToPrecomputed() { + dstVortexCtx.Items.Precomputeds.PrecomputedColums = ctx.Translator.TranslateColumnList(srcVortexCtx.Items.Precomputeds.PrecomputedColums) + dstVortexCtx.Items.Precomputeds.MerkleRoot = ctx.Translator.GetColumn(srcVortexCtx.Items.Precomputeds.MerkleRoot.GetColID()) + dstVortexCtx.Items.Precomputeds.CommittedMatrix = srcVortexCtx.Items.Precomputeds.CommittedMatrix + dstVortexCtx.Items.Precomputeds.DhWithMerkle = srcVortexCtx.Items.Precomputeds.DhWithMerkle + dstVortexCtx.Items.Precomputeds.Tree = srcVortexCtx.Items.Precomputeds.Tree + } + + dstVortexCtx.Items.Alpha = ctx.Translator.InsertCoin(srcVortexCtx.Items.Alpha) + dstVortexCtx.Items.Ualpha = ctx.Translator.InsertColumn(srcVortexCtx.Items.Ualpha.(column.Natural)) + dstVortexCtx.Items.Q = ctx.Translator.InsertCoin(srcVortexCtx.Items.Q) + dstVortexCtx.Items.OpenedColumns = ctx.Translator.InsertColumns(srcVortexCtx.Items.OpenedColumns) + dstVortexCtx.Items.MerkleProofs = ctx.Translator.InsertColumn(srcVortexCtx.Items.MerkleProofs.(column.Natural)) + dstVortexCtx.Items.MerkleRoots = ctx.Translator.TranslateColumnList(srcVortexCtx.Items.MerkleRoots) + + ctx.PcsCtx = dstVortexCtx +} diff --git a/prover/protocol/distributed/conglomeration/conglomeration_test.go b/prover/protocol/distributed/conglomeration/conglomeration_test.go new file mode 100644 index 00000000000..d40bc10c6ca --- /dev/null +++ b/prover/protocol/distributed/conglomeration/conglomeration_test.go @@ -0,0 +1,277 @@ +package conglomeration_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/crypto/ringsis" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/accessors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/compiler" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/mimc" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/selfrecursion" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/conglomeration" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +type ( + defineFuncType func(*wizard.Builder) + proverFuncType func(int) func(*wizard.ProverRuntime) + compilerSuite []func(*wizard.CompiledIOP) +) + +var ( + commonSisParams = &ringsis.Params{LogTwoBound: 16, LogTwoDegree: 1} + commonVortexStep = vortex.Compile( + 2, // the inverse-rate of the RS code + vortex.WithSISParams(commonSisParams), + vortex.ForceNumOpenedColumns(2), + vortex.PremarkAsSelfRecursed(), + ) + vortexOnlyCompilationSuite = []func(*wizard.CompiledIOP){ + commonVortexStep, + } + arcaneCompilationSuite = []func(*wizard.CompiledIOP){ + compiler.Arcane(1<<8, 1<<10, false), + commonVortexStep, + } + arcaneAndSelfRecCompilationSuite = []func(*wizard.CompiledIOP){ + compiler.Arcane(1, 1<<10, false), + commonVortexStep, + selfrecursion.SelfRecurse, + mimc.CompileMiMC, + compiler.Arcane(1, 1<<10, false), + commonVortexStep, + } + // arcaneFullRecSelfRecCompilationSuite = []func(*wizard.CompiledIOP){ + // compiler.Arcane(1<<8, 1<<10, false), + // commonVortexStep, + // selfrecursion.SelfRecurse, + // mimc.CompileMiMC, + // compiler.Arcane(1<<8, 1<<10, false), + // commonVortexStep, + // fullrecursion.FullRecursion(true), + // mimc.CompileMiMC, + // compiler.Arcane(1<<8, 1<<10, false), + // commonVortexStep, + // } +) + +type conglomerationTestCase struct { + define defineFuncType + prove proverFuncType + numProof int + suite compilerSuite +} + +func TestConglomerationPureVortexSingleRound(t *testing.T) { + + var ( + numCol = 16 + numRow = 16 + numProof = 2 + a []ifaces.Column + u query.UnivariateEval + ) + + define := func(builder *wizard.Builder) { + for i := 0; i < numCol; i++ { + a = append(a, builder.RegisterCommit(ifaces.ColIDf("a-%v", i), numRow)) + } + u = builder.CompiledIOP.InsertUnivariate(0, "u", a) + builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) + builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + } + + prover := func(k int) func(run *wizard.ProverRuntime) { + return func(run *wizard.ProverRuntime) { + ys := make([]field.Element, 0, len(a)) + for i := range a { + y := field.NewElement(uint64(i + k)) + run.AssignColumn(a[i].GetColID(), smartvectors.NewConstant(y, numRow)) + ys = append(ys, y) + } + run.AssignUnivariate(u.QueryID, field.NewElement(0), ys...) + } + } + + runConglomerationTestCase(t, conglomerationTestCase{ + define: define, + prove: prover, + numProof: numProof, + suite: vortexOnlyCompilationSuite, + }) +} + +func TestConglomerationPureVortexMultiRound(t *testing.T) { + + var ( + numRound = 4 + numCol = 4 + numRow = 16 + numProof = 16 + a []ifaces.Column + ) + + define := func(builder *wizard.Builder) { + + allCols := make([]ifaces.Column, 0, numCol*numRound) + + for round := 0; round < numRound; round++ { + + if round > 0 { + _ = builder.InsertCoin(round, coin.Namef("c-%v", round), coin.Field) + } + + roundCols := make([]ifaces.Column, 0, numCol) + + for i := 0; i < numCol; i++ { + newCol := builder.InsertCommit(round, ifaces.ColIDf("a-%v-%v", round, i), numRow) + roundCols = append(roundCols, newCol) + } + + if round > 0 { + builder.SubProvers.AppendToInner(round, func(run *wizard.ProverRuntime) { + for i := range roundCols { + x := field.NewElement(uint64(round*numCol + i)) + run.AssignColumn(roundCols[i].GetColID(), smartvectors.NewConstant(x, numRow)) + } + + if round == numRound-1 { + + } + }) + } + + if round == 0 { + a = roundCols + } + + allCols = append(allCols, roundCols...) + } + + u := builder.CompiledIOP.InsertUnivariate(numRound-1, "u", allCols) + + builder.CompiledIOP.SubProvers.AppendToInner(numRound-1, func(run *wizard.ProverRuntime) { + ys := make([]field.Element, 0, len(allCols)) + for _, col := range allCols { + ys = append(ys, col.GetColAssignmentAt(run, 0)) + } + run.AssignUnivariate(u.QueryID, field.NewElement(0), ys...) + }) + + builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) + builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + } + + prover := func(k int) func(run *wizard.ProverRuntime) { + return func(run *wizard.ProverRuntime) { + for i := range a { + y := field.NewElement(uint64(i + k)) + run.AssignColumn(a[i].GetColID(), smartvectors.NewConstant(y, numRow)) + } + } + } + + runConglomerationTestCase(t, conglomerationTestCase{ + define: define, + prove: prover, + numProof: numProof, + suite: vortexOnlyCompilationSuite, + }) +} + +func TestConglomerationLookup(t *testing.T) { + + logrus.SetLevel(logrus.FatalLevel) + + tcs := []struct { + name string + suite compilerSuite + }{ + { + name: "arcane", + suite: arcaneCompilationSuite, + }, + { + name: "arcane/self-recursion", + suite: arcaneAndSelfRecCompilationSuite, + }, + // { + // name: "arcane/full-recursion/self-recursion", + // suite: arcaneFullRecSelfRecCompilationSuite, + // }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + + var ( + numCol = 16 + numRow = 16 + numProof = 4 + a []ifaces.Column + ) + + define := func(builder *wizard.Builder) { + for i := 0; i < numCol; i++ { + a = append(a, builder.RegisterCommit(ifaces.ColIDf("a-%v", i), numRow)) + builder.Range(ifaces.QueryIDf("range-%v", i), a[i], 1<<8) + } + + builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) + builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + } + + prover := func(k int) func(run *wizard.ProverRuntime) { + return func(run *wizard.ProverRuntime) { + for i := range a { + y := field.NewElement(uint64(i + k)) + run.AssignColumn(a[i].GetColID(), smartvectors.NewConstant(y, numRow)) + } + } + } + + runConglomerationTestCase(t, conglomerationTestCase{ + define: define, + prove: prover, + numProof: numProof, + suite: tc.suite, + }) + }) + } +} + +func runConglomerationTestCase(t *testing.T, tc conglomerationTestCase) { + + var ( + numProof = tc.numProof + tmpl = wizard.Compile(wizard.DefineFunc(tc.define), tc.suite...) + congDef, ctxsPHolder = conglomeration.ConglomerateDefineFunc(tmpl, numProof) + cong = wizard.Compile(congDef, dummy.CompileAtProverLvl) + ctxs = *ctxsPHolder + lastRound = ctxs[0].LastRound + ) + + witnesses := make([]conglomeration.Witness, numProof) + for i := range witnesses { + runtime := wizard.RunProverUntilRound(tmpl, tc.prove(i), lastRound+1) + witnesses[i] = conglomeration.ExtractWitness(runtime) + } + + proof := wizard.Prove(cong, conglomeration.ProveConglomeration(ctxs, witnesses)) + err := wizard.Verify(cong, proof) + + require.NoError(t, err) +} diff --git a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go new file mode 100644 index 00000000000..431aa978776 --- /dev/null +++ b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go @@ -0,0 +1,105 @@ +package conglomeration + +import ( + "errors" + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +// crossSegmentCheclk is a verifier action that performs cross-segment checks: +// for instance, it checks that the log-derivative sums all sums to 0 and that +// the grand product is 1. The goal is to ensure that the lookups, permutations +// in the original protocol are satisfied. +type CrossSegmentCheck struct { + Ctxs []*recursionCtx + skip bool +} + +// Run implements the [wizard.VerifierAction], it handles the cross checks over +// the public inputs. for example the global sum over the LogDerivativeSum from +// different segments should be zero. +func (pir *CrossSegmentCheck) Run(run wizard.Runtime) error { + + var ( + logDerivSumAcc, grandSumAcc field.Element + grandProductAcc = field.One() + err error + ) + + for _, ctx := range pir.Ctxs { + + var ( + wrappedRun = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + tmpl = ctx.Tmpl + logDerivSum = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput).GetVal(wrappedRun) + grandProd = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput).GetVal(wrappedRun) + grandSum = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput).GetVal(wrappedRun) + ) + + logDerivSumAcc.Add(&logDerivSumAcc, &logDerivSum) + grandSumAcc.Add(&grandSumAcc, &grandSum) + grandProductAcc.Mul(&grandProductAcc, &grandProd) + } + + if logDerivSumAcc != field.Zero() { + err = errors.Join(err, fmt.Errorf("the global sum over LogDerivSumParams is not zero,"+ + " maybe the same coin over different modules has different values")) + } + + if grandProductAcc != field.One() { + err = errors.Join(err, fmt.Errorf("the global product overGrandProductParams is not 1,"+ + " maybe the same coin over different modules has different values")) + } + + if grandSumAcc != field.Zero() { + err = errors.Join(err, fmt.Errorf("the global sum over GrandSumParams is not zero,"+ + " maybe the same coin over different modules has different values")) + } + + if err != nil { + return fmt.Errorf("[conglomeration.crossSegmentConsistency] %w", err) + } + + return nil +} + +// RunGnark implements the [wizard.VerifierAction] +func (pir *CrossSegmentCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + var ( + logDerivSumAcc = frontend.Variable(0) + grandSumAcc = frontend.Variable(0) + grandProductAcc = frontend.Variable(1) + ) + + for _, ctx := range pir.Ctxs { + + var ( + wrappedRun = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + tmpl = ctx.Tmpl + logDerivSum = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput).GetFrontendVariable(api, wrappedRun) + grandProd = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput).GetFrontendVariable(api, wrappedRun) + grandSum = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput).GetFrontendVariable(api, wrappedRun) + ) + + logDerivSumAcc = api.Add(logDerivSumAcc, logDerivSum) + grandSumAcc = api.Add(grandSumAcc, grandSum) + grandProductAcc = api.Mul(grandProductAcc, grandProd) + } + + api.AssertIsEqual(logDerivSumAcc, field.Zero()) + api.AssertIsEqual(grandProductAcc, field.One()) + api.AssertIsEqual(grandSumAcc, field.Zero()) +} + +func (v *CrossSegmentCheck) Skip() { + v.skip = true +} + +func (v *CrossSegmentCheck) IsSkipped() bool { + return v.skip +} diff --git a/prover/protocol/distributed/conglomeration/prover.go b/prover/protocol/distributed/conglomeration/prover.go new file mode 100644 index 00000000000..42c935395e5 --- /dev/null +++ b/prover/protocol/distributed/conglomeration/prover.go @@ -0,0 +1,241 @@ +package conglomeration + +import ( + "strings" + + "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/crypto/state-management/smt" + vCom "github.com/consensys/linea-monorepo/prover/crypto/vortex" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/vortex" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" +) + +const ( + subProofInStatePrefixStr = ".subProof" + finalFsStateInStateStr = ".finalFsState" +) + +// Witness is a collection of inputs corresponding to a segment proof to provide +// to the main prover of a conglomerate comp. +type Witness struct { + Proof wizard.Proof + CommittedMatrices []vCom.EncodedMatrix + SisHashes [][]field.Element + Trees []*smt.Tree + FinalFsState []field.Element +} + +// PreVortexProverStep is a step replicating the prover of the tmpl at round +// `Round` before the Vortex compilation step. It works by adding columns from +// a wizard proof stored in the prover runtime. The proof is fetched from the +// runtime state. That means the prover step should be run after the proof has +// been attached to the runtime. +type PreVortexProverStep struct { + Ctxs []*recursionCtx + Round int +} + +// FsJoinProverStep is prover step setting the fiat-shamir state of the main +// transcript to the hash of the final states of the subproofs. +type FsJoinProverStep struct { + Ctxs []*recursionCtx + Round int +} + +// AssignVortexQuery assigns the query for all the subproofs. +type AssignVortexQuery struct { + Ctxs []*recursionCtx +} + +// AssignVortexUAlpha assigns the UAlpha column for all the subproofs. As +// for [PreVortexVerifierStep], this step should be run after the corresponding +// proofs have been added to the runtime states. +type AssignVortexUAlpha struct { + Ctxs []*recursionCtx +} + +// AssignVortexOpenedCols assigns the OpenedCols for all the subproofs. As +// for [PreVortexVerifierStep], this step should be run after the corresponding +// proofs have been added to the runtime states. +type AssignVortexOpenedCols struct { + Ctxs []*recursionCtx +} + +// ProveConglomeration returns the main prover step of the conglomeration wizard. +// It takes a list of [Witness] as input and complete the list with the last +// value. +func ProveConglomeration(ctxs []*recursionCtx, witnesses []Witness) wizard.ProverStep { + + if len(witnesses) > len(ctxs) { + utils.Panic("More witnesses than ctxs, numWitnesses: %v, numCtxs: %v", len(witnesses), len(ctxs)) + } + + return func(run *wizard.ProverRuntime) { + + for i, witness := range witnesses { + storeWitnessInState(run, ctxs[i], witness) + } + + for i := len(witnesses); i < len(ctxs); i++ { + storeWitnessInState(run, ctxs[i], witnesses[len(witnesses)-1]) + } + } +} + +func storeWitnessInState(run *wizard.ProverRuntime, ctx *recursionCtx, witness Witness) { + + var ( + prefix = ctx.Translator.Prefix + lastRound = ctx.LastRound + ) + + run.State.InsertNew(prefix+subProofInStatePrefixStr, witness.Proof) + run.State.InsertNew(prefix+finalFsStateInStateStr, witness.FinalFsState) + + for round := 0; round <= lastRound; round++ { + + if len(witness.CommittedMatrices) > round && witness.CommittedMatrices[round] != nil { + run.State.InsertNew(ctx.PcsCtx.VortexProverStateName(round), witness.CommittedMatrices[round]) + } + + if len(witness.SisHashes) > round && witness.SisHashes[round] != nil { + run.State.InsertNew(ctx.PcsCtx.SisHashName(round), witness.SisHashes[round]) + } + + if len(witness.Trees) > round && witness.Trees[round] != nil { + run.State.InsertNew(ctx.PcsCtx.MerkleTreeName(round), witness.Trees[round]) + } + } +} + +// ExtractWitness extracts a [Witness] from a prover runtime toward being conglomerated. +func ExtractWitness(run *wizard.ProverRuntime) Witness { + + var ( + pcs = run.Spec.PcsCtxs.(*vortex.Ctx) + committedMatrices []vCom.EncodedMatrix + sisHashes [][]field.Element + trees []*smt.Tree + lastRound = run.Spec.QueriesParams.Round(pcs.Query.QueryID) + ) + + for round := 0; round <= lastRound; round++ { + + var ( + committedMatrix, _ = run.State.TryGet(pcs.VortexProverStateName(round)) + sisHash, _ = run.State.TryGet(pcs.SisHashName(round)) + tree, _ = run.State.TryGet(pcs.MerkleTreeName(round)) + ) + + if committedMatrix != nil { + committedMatrices = append(committedMatrices, committedMatrix.(vCom.EncodedMatrix)) + sisHashes = append(sisHashes, sisHash.([]field.Element)) + trees = append(trees, tree.(*smt.Tree)) + } else { + committedMatrices = append(committedMatrices, nil) + sisHashes = append(sisHashes, nil) + trees = append(trees, nil) + } + } + + return Witness{ + Proof: run.ExtractProof(), + CommittedMatrices: committedMatrices, + SisHashes: sisHashes, + Trees: trees, + FinalFsState: run.FS.State(), + } +} + +func (pa PreVortexProverStep) Run(run *wizard.ProverRuntime) { + for _, ctx := range pa.Ctxs { + + var ( + prefix = ctx.Translator.Prefix + proof = run.State.MustGet(prefix + subProofInStatePrefixStr).(wizard.Proof) + queriesParams = ctx.QueryParams[pa.Round] + colums = ctx.Columns[pa.Round] + columnIgnored = ctx.ColumnsIgnored[pa.Round] + ) + + for _, col := range colums { + name := unprefix(prefix, col.GetColID()) + + if col.(column.Natural).Status() == column.VerifyingKey { + // those don't need to be assigned and not included in the + // proof either. + continue + } + + run.AssignColumn(col.GetColID(), proof.Messages.MustGet(name)) + } + + for _, col := range columnIgnored { + run.AssignColumn(col.GetColID(), smartvectors.NewConstant(field.Zero(), col.Size())) + } + + for _, param := range queriesParams { + name := unprefix(prefix, param.Name()) + run.QueriesParams.InsertNew(param.Name(), proof.QueriesParams.MustGet(name)) + } + } +} + +func (pa AssignVortexQuery) Run(run *wizard.ProverRuntime) { + for _, ctx := range pa.Ctxs { + + var ( + prefix = ctx.Translator.Prefix + proof = run.State.MustGet(prefix + subProofInStatePrefixStr).(wizard.Proof) + name = unprefix(prefix, ctx.PcsCtx.Query.QueryID) + ) + + run.QueriesParams.InsertNew(ctx.PcsCtx.Query.QueryID, proof.QueriesParams.MustGet(name)) + } +} + +func (pa AssignVortexUAlpha) Run(run *wizard.ProverRuntime) { + for _, ctx := range pa.Ctxs { + // Since all the context of the pcs is translated, this does not + // need to run over a translated prover runtime. + ctx.PcsCtx.ComputeLinearCombFromRsMatrix(run) + } +} + +func (pa AssignVortexOpenedCols) Run(run *wizard.ProverRuntime) { + for _, ctx := range pa.Ctxs { + // Since all the context of the pcs is translated, this does not + // need to run over a translated prover runtime. + ctx.PcsCtx.OpenSelectedColumns(run) + } +} + +func unprefix[T ~string](prefix string, name T) T { + p, n := string(prefix)+".", string(name) + r := strings.TrimPrefix(n, p) + return T(r) +} + +func (pa *FsJoinProverStep) Run(run *wizard.ProverRuntime) { + + mainState := field.NewElement(0) + + for _, ctx := range pa.Ctxs { + + var ( + prefix = ctx.Translator.Prefix + fsState = run.State.MustGet(prefix + finalFsStateInStateStr).([]field.Element) + ) + + mainState = mimc.BlockCompression(mainState, fsState[0]) + } + + newState := []field.Element{mainState} + + run.FS.SetState(newState) + run.FiatShamirHistory[pa.Round][1] = newState +} diff --git a/prover/protocol/distributed/conglomeration/translator.go b/prover/protocol/distributed/conglomeration/translator.go new file mode 100644 index 00000000000..4123cf7647e --- /dev/null +++ b/prover/protocol/distributed/conglomeration/translator.go @@ -0,0 +1,414 @@ +package conglomeration + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/crypto/mimc/gkrmimc" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/collection" +) + +var ( + _ wizard.Runtime = &runtimeTranslator{} + _ wizard.GnarkRuntime = &gnarkRuntimeTranslator{} +) + +// compTranslator is a builder struct for building a target [wizard.CompiledIOP] +// instances from another source [wizard.CompiledIOP]. All items in the built +// compiled IOP are prefixed with an identifier. +type compTranslator struct { + Prefix string + Target *wizard.CompiledIOP +} + +// runtimeTranslator is an adapter structure prefixing every ColID and QueryID and +// coin.Name with a prefix string. +type runtimeTranslator struct { + Prefix string + Rt wizard.Runtime +} + +// gnarkRuntimeTranslator is as [runtimeTranslator] but for [wizard.GnarkRuntime] +type gnarkRuntimeTranslator struct { + Prefix string + Rt wizard.GnarkRuntime +} + +// InsertColumn inserts a new column in the target compiled IOP. The column name +// is prefixed with comp.Prefix. The function checks that the passed column does +// not have a precomputed status (e.g. either precomputed or verifying key). +func (comp *compTranslator) InsertColumn(col column.Natural) ifaces.Column { + + switch col.Status() { + case column.Precomputed, column.VerifyingKey: + panic("cannot insert a precomputed or verifying key column as normal column. Use [InsertPrecomputed] instead") + } + name := ifaces.ColID(comp.Prefix) + "." + col.ID + return comp.Target.InsertColumn(col.Round(), name, col.Size(), col.Status()) +} + +// InsertPrecomputed inserts a new column as a precomputed column to the target +// compiled IOP. To differ with [InsertColumn], this method does also add the +// column to the list of precomputed columns. +func (comp *compTranslator) InsertPrecomputed(col column.Natural, ass ifaces.ColAssignment) ifaces.Column { + name := ifaces.ColID(comp.Prefix) + "." + col.ID + + switch col.Status() { + case column.VerifyingKey: + // assertedly, the round of a precomputed column is always 0 + col := comp.Target.InsertColumn(0, name, col.Size(), col.Status()) + comp.Target.Precomputed.InsertNew(name, ass) + return col + case column.Precomputed, column.Ignored: + return comp.Target.InsertPrecomputed(name, ass) + default: + panic(fmt.Sprintf("not a precomputed column: status=%v name=%v", col.Status().String(), col.ID)) + } +} + +// InsertColumns inserts a list of columns in the target compiled IOP by adding +// a prefix to their names. The inputs columns are expected to be of type +// Natural or this will lead to a panic. +func (comp *compTranslator) InsertColumns(cols []ifaces.Column) []ifaces.Column { + res := make([]ifaces.Column, 0, len(cols)) + for i := range cols { + r := comp.InsertColumn(cols[i].(column.Natural)) + res = append(res, r) + } + return res +} + +// GetColumn returns a column from the target compiled IOP. +func (comp *compTranslator) GetColumn(name ifaces.ColID) ifaces.Column { + name = ifaces.ColID(comp.Prefix) + "." + name + return comp.Target.Columns.GetHandle(name) +} + +// ColumnExists returns a boolean indicating of the column is already +// registered in the translator. +func (comp *compTranslator) ColumnExists(name ifaces.ColID) bool { + name = ifaces.ColID(comp.Prefix) + "." + name + return comp.Target.Columns.Exists(name) +} + +// InsertCoin inserts a new coin in the target compiled IOP. The coin name +// is prefixed with the comp.Prefix. +func (comp *compTranslator) InsertCoin(info coin.Info) coin.Info { + name := coin.Name(comp.Prefix) + "." + info.Name + switch info.Type { + case coin.IntegerVec: + return comp.Target.InsertCoin(info.Round, name, info.Type, info.Size, info.UpperBound) + case coin.Field: + return comp.Target.InsertCoin(info.Round, name, info.Type) + default: + panic("unknown coin type") + } +} + +// GetCoin returns a coin with the prefixed name in the target compiled IOP. +// It panics if the prefixed coin is not found. +func (comp *compTranslator) GetCoin(name coin.Name) coin.Info { + name = coin.Name(comp.Prefix) + "." + name + return comp.Target.Coins.Data(name) +} + +// InsertQueryParams inserts a new query in the target compiled IOP prefixing the +// query name however the inner-fields of the query are not prefixed or translated. +// So it should be preferrably applied only over "Ignored" queries as the content of +// the inserted query will be invalid. +func (comp *compTranslator) InsertQueryParams(round int, q ifaces.Query) ifaces.Query { + name := ifaces.QueryID(comp.Prefix) + "." + q.Name() + + var q2 ifaces.Query + switch q := q.(type) { + case query.UnivariateEval: + q2 = query.NewUnivariateEval(name, q.Pols...) + case query.LocalOpening: + q2 = query.NewLocalOpening(name, q.Pol) + case query.InnerProduct: + q2 = query.NewInnerProduct(name, q.A, q.Bs...) + case query.GrandProduct: + q2 = query.NewGrandProduct(q.Round, q.Inputs, name) + case query.LogDerivativeSum: + q2 = query.NewLogDerivativeSum(q.Round, q.Inputs, name) + default: + panic("unknown query type") + } + + comp.Target.QueriesParams.AddToRound(round, name, q2) + comp.Target.QueriesParams.MarkAsIgnored(q2.Name()) + return q2 +} + +// TranslateColumnList translates a collection of pre-inserted columns. +// If one of the columns provided in the list is nil, it will be ignored +// and the function will return nil at the same position in the returned +// list of column. If the column is non-nil but not found in the translated +// comp, then it is assumed the column is a verifier col and the column is +// returned as is. +func (comp *compTranslator) TranslateColumnList(cols []ifaces.Column) []ifaces.Column { + res := make([]ifaces.Column, 0, len(cols)) + for _, col := range cols { + + if col == nil { + res = append(res, nil) + continue + } + + if !comp.ColumnExists(col.GetColID()) { + if _, ok := col.(verifiercol.VerifierCol); !ok { + utils.Panic("expects all the unregistered methods to be verifiercol, but got type=%T for column name=%v", col, col.GetColID()) + } + } + + res = append(res, comp.GetColumn(col.GetColID())) + } + return res +} + +// TranslateColumnVecVec translates a collection of pre-inserted columns +func (comp *compTranslator) TranslateColumnVecVec(cols collection.VecVec[ifaces.ColID]) collection.VecVec[ifaces.ColID] { + var res = collection.NewVecVec[ifaces.ColID]() + for r, vec := range cols.Inner() { + for _, c := range vec { + + // If it does not exists, then it is assumed to be a verifier column + if !comp.ColumnExists(c) { + res.AppendToInner(r, c) + continue + } + + res.AppendToInner(r, comp.GetColumn(c).GetColID()) + } + } + return res +} + +// TranslateColumnSet translates a set of pre-inserted columns +func (comp *compTranslator) TranslateColumnSet(cols map[ifaces.ColID]struct{}) map[ifaces.ColID]struct{} { + var res = make(map[ifaces.ColID]struct{}) + for col := range cols { + + // If it does not exists, then it is assumed to be a verifier column + if !comp.ColumnExists(col) { + res[col] = struct{}{} + continue + } + + res[comp.GetColumn(col).GetColID()] = struct{}{} + } + return res +} + +// TranslateUniEval returns a copied UnivariateEval query with the columns translated +// and the names translated. The returned query is registered in the translator comp. +func (comp *compTranslator) TranslateUniEval(round int, q query.UnivariateEval) query.UnivariateEval { + newPols := make([]ifaces.Column, len(q.Pols)) + for i := range newPols { + if _, ok := q.Pols[i].(verifiercol.VerifierCol); ok { + newPols[i] = q.Pols[i] + continue + } + + newPols[i] = comp.GetColumn(q.Pols[i].GetColID()) + } + var res = query.NewUnivariateEval(q.QueryID, newPols...) + return comp.InsertQueryParams(round, res).(query.UnivariateEval) +} + +func (run *runtimeTranslator) GetColumn(name ifaces.ColID) ifaces.ColAssignment { + name = ifaces.ColID(run.Prefix) + "." + name + return run.Rt.GetColumn(name) +} + +func (run *runtimeTranslator) GetColumnAt(name ifaces.ColID, pos int) field.Element { + name = ifaces.ColID(run.Prefix) + "." + name + return run.Rt.GetColumnAt(name, pos) +} + +func (run *runtimeTranslator) GetRandomCoinField(name coin.Name) field.Element { + name = coin.Name(run.Prefix) + "." + name + return run.Rt.GetRandomCoinField(name) +} + +func (run *runtimeTranslator) GetRandomCoinIntegerVec(name coin.Name) []int { + name = coin.Name(run.Prefix) + "." + name + return run.Rt.GetRandomCoinIntegerVec(name) +} + +func (run *runtimeTranslator) GetParams(id ifaces.QueryID) ifaces.QueryParams { + id = ifaces.QueryID(run.Prefix) + "." + id + return run.Rt.GetParams(id) +} + +func (run *runtimeTranslator) GetSpec() *wizard.CompiledIOP { + return run.Rt.GetSpec() +} + +func (run *runtimeTranslator) GetPublicInput(name string) field.Element { + name = run.Prefix + "." + name + return run.Rt.GetPublicInput(name) +} + +func (run *runtimeTranslator) GetGrandProductParams(name ifaces.QueryID) query.GrandProductParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetGrandProductParams(name) +} + +func (run *runtimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetLogDerivSumParams(name) +} + +func (run *runtimeTranslator) GetLocalPointEvalParams(name ifaces.QueryID) query.LocalOpeningParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetLocalPointEvalParams(name) +} + +func (run *runtimeTranslator) GetInnerProductParams(name ifaces.QueryID) query.InnerProductParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetInnerProductParams(name) +} + +func (run *runtimeTranslator) GetUnivariateEval(name ifaces.QueryID) query.UnivariateEval { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetUnivariateEval(name) +} + +func (run *runtimeTranslator) GetUnivariateParams(name ifaces.QueryID) query.UnivariateEvalParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetUnivariateParams(name) +} + +func (run *runtimeTranslator) Fs() *fiatshamir.State { + return run.Rt.Fs() +} + +func (run *runtimeTranslator) FsHistory() [][2][]field.Element { + return run.Rt.FsHistory() +} + +func (run *runtimeTranslator) InsertCoin(name coin.Name, value any) { + name = coin.Name(run.Prefix) + "." + name + run.Rt.InsertCoin(name, value) +} + +func (run *runtimeTranslator) GetState(name string) (any, bool) { + name = run.Prefix + "." + name + return run.Rt.GetState(name) +} + +func (run *runtimeTranslator) SetState(name string, value any) { + name = run.Prefix + "." + name + run.Rt.SetState(name, value) +} + +func (run *runtimeTranslator) GetQuery(name ifaces.QueryID) ifaces.Query { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetQuery(name) +} + +func (run *gnarkRuntimeTranslator) GetColumn(name ifaces.ColID) []frontend.Variable { + name = ifaces.ColID(run.Prefix) + "." + name + return run.Rt.GetColumn(name) +} + +func (run *gnarkRuntimeTranslator) GetColumnAt(name ifaces.ColID, at int) frontend.Variable { + name = ifaces.ColID(run.Prefix) + "." + name + return run.Rt.GetColumnAt(name, at) +} + +func (run *gnarkRuntimeTranslator) GetRandomCoinField(name coin.Name) frontend.Variable { + name = coin.Name(run.Prefix) + "." + name + return run.Rt.GetRandomCoinField(name) +} + +func (run *gnarkRuntimeTranslator) GetRandomCoinIntegerVec(name coin.Name) []frontend.Variable { + name = coin.Name(run.Prefix) + "." + name + return run.Rt.GetRandomCoinIntegerVec(name) +} + +func (run *gnarkRuntimeTranslator) GetParams(id ifaces.QueryID) ifaces.GnarkQueryParams { + id = ifaces.QueryID(run.Prefix) + "." + id + return run.Rt.GetParams(id) +} + +func (run *gnarkRuntimeTranslator) GetSpec() *wizard.CompiledIOP { + return run.Rt.GetSpec() +} + +func (run *gnarkRuntimeTranslator) GetPublicInput(api frontend.API, name string) frontend.Variable { + name = run.Prefix + "." + name + return run.Rt.GetPublicInput(api, name) +} + +func (run *gnarkRuntimeTranslator) GetGrandProductParams(name ifaces.QueryID) query.GnarkGrandProductParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetGrandProductParams(name) +} + +func (run *gnarkRuntimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetLogDerivSumParams(name) +} + +func (run *gnarkRuntimeTranslator) GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetLocalPointEvalParams(name) +} + +func (run *gnarkRuntimeTranslator) GetInnerProductParams(name ifaces.QueryID) query.GnarkInnerProductParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetInnerProductParams(name) +} + +func (run *gnarkRuntimeTranslator) GetUnivariateEval(name ifaces.QueryID) query.UnivariateEval { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetUnivariateEval(name) +} + +func (run *gnarkRuntimeTranslator) GetUnivariateParams(name ifaces.QueryID) query.GnarkUnivariateEvalParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetUnivariateParams(name) +} + +func (run *gnarkRuntimeTranslator) Fs() *fiatshamir.GnarkFiatShamir { + return run.Rt.Fs() +} + +func (run *gnarkRuntimeTranslator) FsHistory() [][2][]frontend.Variable { + return run.Rt.FsHistory() +} + +func (run *gnarkRuntimeTranslator) GetHasherFactory() *gkrmimc.HasherFactory { + return run.Rt.GetHasherFactory() +} + +func (run *gnarkRuntimeTranslator) InsertCoin(name coin.Name, value any) { + name = coin.Name(run.Prefix) + "." + name + run.Rt.InsertCoin(name, value) +} + +func (run *gnarkRuntimeTranslator) GetState(name string) (any, bool) { + name = run.Prefix + "." + name + return run.Rt.GetState(name) +} + +func (run *gnarkRuntimeTranslator) SetState(name string, value any) { + name = run.Prefix + "." + name + run.Rt.SetState(name, value) +} + +func (run *gnarkRuntimeTranslator) GetQuery(name ifaces.QueryID) ifaces.Query { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetQuery(name) +} diff --git a/prover/protocol/distributed/conglomeration/verifier.go b/prover/protocol/distributed/conglomeration/verifier.go new file mode 100644 index 00000000000..cbc3f8fb025 --- /dev/null +++ b/prover/protocol/distributed/conglomeration/verifier.go @@ -0,0 +1,355 @@ +package conglomeration + +import ( + "errors" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +const ( + fiatShamirHistoryStr = "fiat-shamir-history" + fiatShamirTranscriptStr = "fiat-shamir-transcript" +) + +var _ wizard.Runtime = &RuntimeWithReplacedFS{} + +// PreVortexVerifierStep is a step replicating the verifier of the tmpl at round +// `Round` before the Vortex compilation step. +type PreVortexVerifierStep struct { + Ctxs []*recursionCtx + Round int + isSkipped bool +} + +// RuntimeWithReplacedFS is a runtime that wraps another runtime and replaces +// the returns of [Fs] and [FsHistory] with the ones provided in the struct. +type RuntimeWithReplacedFS struct { + wizard.Runtime + FS *fiatshamir.State + FiatShamirHistory [][2][]field.Element +} + +// GnarkRuntimeWithReplacedFS is a GnarkRuntime that wraps another runtime and +// replaces the returns of [Fs] and [FsHistory] with the ones provided in the struct. +type GnarkRuntimeWithReplacedFS struct { + wizard.GnarkRuntime + FS *fiatshamir.GnarkFiatShamir + FiatShamirHistory [][2][]frontend.Variable +} + +// FsJoinHook is a fiat-shamir hook whose purpose is to join the interal +// fiat-shamir states of the segment proofs into the main one. It works +// by setting the main fs-state as the hash of the internal states. +type FsJoinHook struct { + Ctxs []*recursionCtx + isSkipped bool +} + +// SubFsInitialize is a fiat-shamir hook whose purpose is to initialize the +// internal fiat-shamir states of the segment proofs. It is set as a FS hook +// as it is guaranteed to be run before any prover/verifier step, ensuring the +// fs states are available at the beginning. +type SubFsInitialize struct { + Ctxs []*recursionCtx + isSkipped bool +} + +func (pa PreVortexVerifierStep) Run(run wizard.Runtime) error { + + var err error + + for _, ctx := range pa.Ctxs { + + generateRandomCoins(run, ctx, pa.Round) + + // Wraps the runtime into a translation adapter + var ( + wrappedRun = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + ) + + // Copy the verifier actions of the template into the target + for _, va := range ctx.VerifierActions[pa.Round] { + err = errors.Join(err, va.Run(wrappedRun)) + } + } + + return err +} + +// generateRandomCoins generates all the coins for the current round +// so that they are made available to the forthcoming verifier actions. +func generateRandomCoins(run wizard.Runtime, ctx *recursionCtx, currRound int) { + + var ( + spec = run.GetSpec() + // Wraps the runtime into a translation adapter, first to get the FS state + // and history. + wrappedRun wizard.Runtime = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fsAny, _ = wrappedRun.GetState(fiatShamirTranscriptStr) + fsHistoryAny, _ = wrappedRun.GetState(fiatShamirHistoryStr) + fs = fsAny.(*fiatshamir.State) + fsHistory = fsHistoryAny.([][2][]field.Element) + initialState = fs.State() + ) + + // Wraps it a second time + wrappedRun = &RuntimeWithReplacedFS{ + Runtime: wrappedRun, + FS: fs, + FiatShamirHistory: fsHistory, + } + + if currRound > 0 { + + cols := ctx.Columns[currRound-1] + for _, col := range cols { + + name := unprefix(ctx.Translator.Prefix, col.GetColID()) + if ctx.Tmpl.Columns.IsExplicitlyExcludedFromProverFS(name) { + continue + } + + instance := run.GetColumn(col.GetColID()) + fs.UpdateSV(instance) + } + + queries := ctx.QueryParams[currRound-1] + for _, q := range queries { + params := run.GetParams(q.Name()) + params.UpdateFS(fs) + } + } + + toCompute := ctx.Coins[currRound] + for _, coin := range toCompute { + info := spec.Coins.Data(coin.Name) + value := info.Sample(fs) + run.InsertCoin(coin.Name, value) + } + + for _, fsHook := range ctx.FsHooks[currRound] { + fsHook.Run(wrappedRun) + } + + fsHistory[currRound] = [2][]field.Element{ + initialState, + fs.State(), + } + + wrappedRun.SetState(fiatShamirHistoryStr, fsHistory) + wrappedRun.SetState(fiatShamirTranscriptStr, fs) +} + +// Fs returns the Fiat-Shamir state +func (run *RuntimeWithReplacedFS) Fs() *fiatshamir.State { + return run.FS +} + +// FsHistory returns the Fiat-Shamir state history +func (run *RuntimeWithReplacedFS) FsHistory() [][2][]field.Element { + return run.FiatShamirHistory +} + +func (pa PreVortexVerifierStep) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + for _, ctx := range pa.Ctxs { + + pa.generateRandomCoinsGnark(api, run, ctx, pa.Round) + + // Wraps the runtime into a translation adapter + var wrappedRun = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + + // Copy the verifier actions of the template into the target + for _, va := range ctx.VerifierActions[pa.Round] { + va.RunGnark(api, wrappedRun) + } + } +} + +// generateRandomCoinsGnark generates all the coins for the current round +// so that they are made available to the forthcoming verifier actions. +func (pa PreVortexVerifierStep) generateRandomCoinsGnark(api frontend.API, run wizard.GnarkRuntime, ctx *recursionCtx, currRound int) { + + const ( + fiatShamirHistoryStr = "fiat-shamir-history" + fiatShamirTranscriptStr = "fiat-shamir-transcript" + ) + + var ( + spec = run.GetSpec() + // Wraps the runtime into a translation adapter, first to get the FS state + // and history. + wrappedRun wizard.GnarkRuntime = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fsAny, _ = wrappedRun.GetState(fiatShamirTranscriptStr) + fsHistoryAny, _ = wrappedRun.GetState(fiatShamirHistoryStr) + fs = fsAny.(*fiatshamir.GnarkFiatShamir) + fsHistory = fsHistoryAny.([][2][]frontend.Variable) + initialState = fs.State() + ) + + // Wraps it a second time + wrappedRun = &GnarkRuntimeWithReplacedFS{ + GnarkRuntime: wrappedRun, + FS: fs, + FiatShamirHistory: fsHistory, + } + + if currRound > 0 { + + cols := ctx.Columns[currRound-1] + for _, col := range cols { + + name := unprefix(ctx.Translator.Prefix, col.GetColID()) + if ctx.Tmpl.Columns.IsExplicitlyExcludedFromProverFS(name) { + continue + } + + instance := run.GetColumn(col.GetColID()) + fs.UpdateVec(instance) + } + + queries := ctx.QueryParams[currRound-1] + for _, q := range queries { + params := run.GetParams(q.Name()) + params.UpdateFS(fs) + } + } + + toCompute := ctx.Coins[currRound] + for _, c := range toCompute { + info := spec.Coins.Data(c.Name) + switch info.Type { + case coin.Field: + value := fs.RandomField() + run.InsertCoin(c.Name, value) + case coin.IntegerVec: + value := fs.RandomManyIntegers(info.Size, info.UpperBound) + run.InsertCoin(c.Name, value) + } + } + + for _, fsHook := range ctx.FsHooks[currRound] { + fsHook.RunGnark(api, wrappedRun) + } + + fsHistory[currRound] = [2][]frontend.Variable{ + initialState, + fs.State(), + } + + wrappedRun.SetState(fiatShamirHistoryStr, fsHistory) + wrappedRun.SetState(fiatShamirTranscriptStr, fs) +} + +// Fs returns the Fiat-Shamir state +func (run *GnarkRuntimeWithReplacedFS) Fs() *fiatshamir.GnarkFiatShamir { + return run.FS +} + +// FsHistory returns the Fiat-Shamir state history +func (run *GnarkRuntimeWithReplacedFS) FsHistory() [][2][]frontend.Variable { + return run.FiatShamirHistory +} + +func (pa *PreVortexVerifierStep) IsSkipped() bool { + return pa.isSkipped +} + +func (pa *PreVortexVerifierStep) Skip() { + pa.isSkipped = true +} + +func (fs *FsJoinHook) Run(run wizard.Runtime) error { + + mainState := field.NewElement(0) + + for _, ctx := range fs.Ctxs { + + var ( + wrappedRun wizard.Runtime = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fsAny, _ = wrappedRun.GetState(fiatShamirTranscriptStr) + fs = fsAny.(*fiatshamir.State) + ) + + mainState = mimc.BlockCompression(mainState, fs.State()[0]) + } + + run.Fs().SetState([]field.Element{mainState}) + + return nil +} + +func (fs *FsJoinHook) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + mainState := frontend.Variable(0) + + for _, ctx := range fs.Ctxs { + + var ( + wrappedRun wizard.GnarkRuntime = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fsAny, _ = wrappedRun.GetState(fiatShamirTranscriptStr) + fs = fsAny.(*fiatshamir.GnarkFiatShamir) + ) + + mainState = mimc.GnarkBlockCompression(api, mainState, fs.State()[0]) + } + + run.Fs().SetState([]frontend.Variable{mainState}) +} + +func (fs *FsJoinHook) Skip() { + fs.isSkipped = true +} + +func (fs *FsJoinHook) IsSkipped() bool { + return fs.isSkipped +} + +func (fs *SubFsInitialize) Run(run wizard.Runtime) error { + + for _, ctx := range fs.Ctxs { + + var ( + wrappedRun wizard.Runtime = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fs = fiatshamir.NewMiMCFiatShamir() + fsHistory = make([][2][]field.Element, ctx.LastRound+1) + ) + + fs.Update(ctx.Tmpl.FiatShamirSetup) + + wrappedRun.SetState(fiatShamirHistoryStr, fsHistory) + wrappedRun.SetState(fiatShamirTranscriptStr, fs) + } + + return nil +} + +func (fs *SubFsInitialize) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + for _, ctx := range fs.Ctxs { + + var ( + wrappedRun wizard.GnarkRuntime = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} + fs = fiatshamir.NewGnarkFiatShamir(api, run.GetHasherFactory()) + fsHistory = make([][2][]frontend.Variable, ctx.LastRound+1) + ) + + fs.Update(api, ctx.Tmpl.FiatShamirSetup) + + wrappedRun.SetState(fiatShamirHistoryStr, fsHistory) + wrappedRun.SetState(fiatShamirTranscriptStr, fs) + } +} + +func (fs *SubFsInitialize) Skip() { + fs.isSkipped = true +} + +func (fs *SubFsInitialize) IsSkipped() bool { + return fs.isSkipped +} diff --git a/prover/protocol/distributed/constants/constant.go b/prover/protocol/distributed/constants/constant.go new file mode 100644 index 00000000000..0a35a13d59a --- /dev/null +++ b/prover/protocol/distributed/constants/constant.go @@ -0,0 +1,7 @@ +package constants + +const ( + LogDerivativeSumPublicInput = "LOG_DERIVATE_SUM_PUBLIC_INPUT" + GrandProductPublicInput = "GRAND_PRODUCT_PUBLIC_INPUT" + GrandSumPublicInput = "GRAND_SUM_PUBLIC_INPUT" +) diff --git a/prover/protocol/distributed/xcomp/xcomp.go b/prover/protocol/distributed/xcomp/xcomp.go deleted file mode 100644 index 0dde03eeb1b..00000000000 --- a/prover/protocol/distributed/xcomp/xcomp.go +++ /dev/null @@ -1,144 +0,0 @@ -package xcomp - -import ( - "fmt" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/accessors" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/utils/collection" -) - -// GetCrossComp generates an (empty) compiledIOP object that is handling the crosse checks -// for example, the global sum over logDerivativeSum is zero. -func GetCrossComp(vRuntimes []*wizard.VerifierRuntime) *wizard.CompiledIOP { - - xComp := wizard.NewCompiledIOP() - - // initialize the PICollector - va := PublicInputCollector{ - PIFromLogDeriv: collection.NewMapping[string, field.Element](), - PIFromGrandProd: collection.NewMapping[string, field.Element](), - PIFromGrandSum: collection.NewMapping[string, field.Element](), - } - // get the publicInputs values from different verifiers. - for i, runtime := range vRuntimes { - va.Index = i - CollectPIfromVerifer(runtime, &va) - } - - // register a verifier action to check the consistency of public inputs - xComp.RegisterVerifierAction(0, &PublicInputChecker{PublicInputCollector: va}) - - return xComp -} - -// PublicInputCollector collects the public input values from different modules/segments. -type PublicInputCollector struct { - // maps for collecting publicInputs from different modules. - PIFromLogDeriv, PIFromGrandProd, PIFromGrandSum collection.Mapping[string, field.Element] - // index for the verifier from which we are receiving the publicInputs - Index int -} - -// CollectPIfromVerifer adds the public inputs of a given verifier to the Collector. -func CollectPIfromVerifer(run *wizard.VerifierRuntime, pic *PublicInputCollector) { - var ( - allPI = run.Spec.PublicInputs - ) - - for _, pi := range allPI { - - name := fmt.Sprintf("%v_%v", pi.Name, pic.Index) - - switch v := pi.Acc.(type) { - - case *accessors.FromLogDerivSumAccessor: - pic.PIFromLogDeriv.InsertNew(name, v.GetVal(run)) - - case *accessors.FromGrandProductAccessor: - pic.PIFromLogDeriv.InsertNew(name, v.GetVal(run)) - - case *accessors.FromGrandSumAccessor: - pic.PIFromLogDeriv.InsertNew(name, v.GetVal(run)) - } - } - -} - -type PublicInputChecker struct { - PublicInputCollector - skip bool -} - -// Run implements the [wizard.VerifierAction], it handles the cross checks over the public inputs. -// for example the global sum over the LogDerivativeSum from different segments should be zero. -func (pir *PublicInputChecker) Run(run *wizard.VerifierRuntime) error { - var ( - logDerivSum, grandSum field.Element - grandProduct = field.One() - ) - - for _, key := range pir.PIFromLogDeriv.ListAllKeys() { - curr := pir.PIFromLogDeriv.MustGet(key) - logDerivSum.Add(&logDerivSum, &curr) - } - for _, key := range pir.PIFromGrandProd.ListAllKeys() { - curr := pir.PIFromGrandProd.MustGet(key) - grandProduct.Add(&grandProduct, &curr) - } - for _, key := range pir.PIFromGrandSum.ListAllKeys() { - curr := pir.PIFromGrandSum.MustGet(key) - grandSum.Add(&grandSum, &curr) - } - - if logDerivSum != field.Zero() { - panic("the global sum over LogDerivSumParams is not zero," + - " maybe the same coin over different modules has different values") - } - - if grandProduct != field.One() { - panic("the global product overGrandProductParams is not 1," + - " maybe the same coin over different modules has different values") - } - - if grandSum != field.Zero() { - panic("the global sum over GrandSumParams is not zero," + - " maybe the same coin over different modules has different values") - } - return nil - -} - -// RunGnark implements the [wizard.VerifierAction] -func (pir *PublicInputChecker) RunGnark(api frontend.API, run *wizard.WizardVerifierCircuit) { - var ( - logDerivSum, grandProduct, grandSum frontend.API - ) - - for _, key := range pir.PIFromLogDeriv.ListAllKeys() { - curr := pir.PIFromLogDeriv.MustGet(key) - api.Add(logDerivSum, curr) - } - for _, key := range pir.PIFromGrandProd.ListAllKeys() { - curr := pir.PIFromGrandProd.MustGet(key) - api.Add(grandProduct, curr) - } - for _, key := range pir.PIFromGrandSum.ListAllKeys() { - curr := pir.PIFromGrandSum.MustGet(key) - api.Add(grandSum, curr) - } - - api.AssertIsEqual(logDerivSum, field.Zero()) - api.AssertIsEqual(grandProduct, field.One()) - api.AssertIsEqual(grandSum, field.Zero()) -} - -func (v *PublicInputChecker) Skip() { - v.skip = true -} - -func (v *PublicInputChecker) IsSkipped() bool { - return v.skip -} diff --git a/prover/protocol/wizard/actions.go b/prover/protocol/wizard/actions.go index 398017fba90..bf59a2e0c80 100644 --- a/prover/protocol/wizard/actions.go +++ b/prover/protocol/wizard/actions.go @@ -20,24 +20,24 @@ type VerifierAction interface { IsSkipped() bool // Run executes the VerifierAction over a [VerifierRuntime] it returns an // error. - Run(*VerifierRuntime) error + Run(Runtime) error // RunGnark is as Run but in a gnark circuit. Instead, of the returning an // error the function enforces the passing of the verifier's checks. - RunGnark(frontend.API, *WizardVerifierCircuit) + RunGnark(frontend.API, GnarkRuntime) } // genVerifierAction represents a verifier action represented by closures type genVerifierAction struct { skipped bool - run func(*VerifierRuntime) error - runGnark func(frontend.API, *WizardVerifierCircuit) + run func(Runtime) error + runGnark func(frontend.API, GnarkRuntime) } -func (gva *genVerifierAction) Run(run *VerifierRuntime) error { +func (gva *genVerifierAction) Run(run Runtime) error { return gva.run(run) } -func (gva *genVerifierAction) RunGnark(api frontend.API, run *WizardVerifierCircuit) { +func (gva *genVerifierAction) RunGnark(api frontend.API, run GnarkRuntime) { gva.runGnark(api, run) } diff --git a/prover/protocol/wizard/builder.go b/prover/protocol/wizard/builder.go index 7b0c3f1d74a..b575730c42f 100644 --- a/prover/protocol/wizard/builder.go +++ b/prover/protocol/wizard/builder.go @@ -315,4 +315,12 @@ func (comp *CompiledIOP) EqualizeRounds(numRounds int) { utils.Panic("Bug : numRounds is %v but %v rounds are registered for the verifier. %v", numRounds, comp.SubVerifiers.Len(), helpMsg) } comp.SubVerifiers.Reserve(numRounds) + + /* + Check and reserve for the FiatShamirHooks + */ + if comp.FiatShamirHooks.Len() > numRounds { + utils.Panic("Bug : numRounds is %v but %v rounds are registered for the FiatShamirHooks. %v", numRounds, comp.FiatShamirHooks.Len(), helpMsg) + } + comp.FiatShamirHooks.Reserve(numRounds) } diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index b24f861bd24..65f11771a88 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -111,13 +111,13 @@ type CompiledIOP struct { // process. An artefact must satisfy the io.ReadWriteTo interface. Artefacts artefactCache - // fiatShamirSetup stores an initial value to use to bootstrap the Fiat-Shamir + // FiatShamirSetup stores an initial value to use to bootstrap the Fiat-Shamir // transcript. This value is obtained by hashing diverse meta-data of the // describing the wizard: a version number, the description of the field, // a description of all the columns and all the queries etc... // - // For efficiency reasons, the fiatShamirSetup is derived using SHA2. - fiatShamirSetup field.Element + // For efficiency reasons, the FiatShamirSetup is derived using SHA2. + FiatShamirSetup field.Element // FunctionalPublic inputs lists the queries representing a public inputs // and their identifiers @@ -174,8 +174,6 @@ func (c *CompiledIOP) InsertColumn(round int, name ifaces.ColID, size int, statu utils.Panic("column %v has size %v", name, size) } - c.assertConsistentRound(round) - if len(name) == 0 { panic("Column with an empty name") } @@ -220,8 +218,6 @@ func (c *CompiledIOP) InsertCoin(round int, name coin.Name, type_ coin.Type, siz // - the definition round is inconsistent with the expression func (c *CompiledIOP) InsertGlobal(round int, name ifaces.QueryID, expr *symbolic.Expression, noBoundCancel ...bool) query.GlobalConstraint { - c.assertConsistentRound(round) - // The constructor of the global constraint is assumed to perform all the // well-formation checks of the constraint. cs := query.NewGlobalConstraint(name, expr, noBoundCancel...) @@ -264,8 +260,6 @@ func (c *CompiledIOP) InsertGlobal(round int, name ifaces.QueryID, expr *symboli // - the definition round is inconsistent with the expression func (c *CompiledIOP) InsertLocal(round int, name ifaces.QueryID, cs_ *symbolic.Expression) query.LocalConstraint { - c.assertConsistentRound(round) - cs := query.NewLocalConstraint(name, cs_) boarded := cs.Board() metadatas := boarded.ListVariableMetadata() @@ -300,7 +294,6 @@ func (c *CompiledIOP) InsertLocal(round int, name ifaces.QueryID, cs_ *symbolic. // - any column in `a` or `b“ is a not registered columns // - a constraint with the same name already exists in the CompiledIOP func (c *CompiledIOP) InsertPermutation(round int, name ifaces.QueryID, a, b []ifaces.Column) query.Permutation { - c.assertConsistentRound(round) query_ := query.NewPermutation(name, [][]ifaces.Column{a}, [][]ifaces.Column{b}) c.QueriesNoParams.AddToRound(round, name, query_) return query_ @@ -310,7 +303,6 @@ func (c *CompiledIOP) InsertPermutation(round int, name ifaces.QueryID, a, b []i // fragmented tables. Meanining that permutation operates over the union of // the rows of multiple tables. func (c *CompiledIOP) InsertFragmentedPermutation(round int, name ifaces.QueryID, a, b [][]ifaces.Column) query.Permutation { - c.assertConsistentRound(round) query_ := query.NewPermutation(name, a, b) c.QueriesNoParams.AddToRound(round, name, query_) return query_ @@ -343,7 +335,6 @@ func (c *CompiledIOP) InsertFixedPermutation(round int, name ifaces.QueryID, p [ // - the columns in `included` do not all have the same size // - a constraint with the same name already exists in the CompiledIOP func (c *CompiledIOP) InsertInclusion(round int, name ifaces.QueryID, including, included []ifaces.Column) { - c.assertConsistentRound(round) query := query.NewInclusion(name, included, [][]ifaces.Column{including}, nil, nil) c.QueriesNoParams.AddToRound(round, name, query) } @@ -353,7 +344,6 @@ Creates an inclusion query. Both the including and the included tables are filte the filters should be columns containing only field elements for 0 and 1 */ func (c *CompiledIOP) InsertInclusionDoubleConditional(round int, name ifaces.QueryID, including, included []ifaces.Column, includingFilter, includedFilter ifaces.Column) { - c.assertConsistentRound(round) query := query.NewInclusion(name, included, [][]ifaces.Column{including}, includedFilter, []ifaces.Column{includingFilter}) c.QueriesNoParams.AddToRound(round, name, query) } @@ -363,7 +353,6 @@ Creates an inclusion query. Only the including table is filtered the filters should be columns containing only field elements for 0 and 1 */ func (c *CompiledIOP) InsertInclusionConditionalOnIncluding(round int, name ifaces.QueryID, including, included []ifaces.Column, includingFilter ifaces.Column) { - c.assertConsistentRound(round) query := query.NewInclusion(name, included, [][]ifaces.Column{including}, nil, []ifaces.Column{includingFilter}) c.QueriesNoParams.AddToRound(round, name, query) } @@ -373,7 +362,6 @@ Creates an inclusion query. Only the included table is filtered the filters should be columns containing only field elements for 0 and 1 */ func (c *CompiledIOP) InsertInclusionConditionalOnIncluded(round int, name ifaces.QueryID, including, included []ifaces.Column, includedFilter ifaces.Column) { - c.assertConsistentRound(round) query := query.NewInclusion(name, included, [][]ifaces.Column{including}, includedFilter, nil) c.QueriesNoParams.AddToRound(round, name, query) } @@ -394,7 +382,6 @@ func (c *CompiledIOP) GenericFragmentedConditionalInclusion( includingFilter []ifaces.Column, includedFilter ifaces.Column, ) { - c.assertConsistentRound(round) query := query.NewInclusion(name, included, including, includedFilter, includingFilter) c.QueriesNoParams.AddToRound(round, name, query) } @@ -437,7 +424,6 @@ func (c *CompiledIOP) InsertPrecomputed(name ifaces.ColID, v smartvectors.SmartV // // The name must be non-empty and unique and the size must be a power of 2. func (c *CompiledIOP) InsertProof(round int, name ifaces.ColID, size int) (msg ifaces.Column) { - c.assertConsistentRound(round) // Common : No zero length if size == 0 { @@ -447,21 +433,6 @@ func (c *CompiledIOP) InsertProof(round int, name ifaces.ColID, size int) (msg i return c.Columns.AddToRound(round, name, size, column.Proof) } -// InsertPublicInput registers a public input column, and specifies static information regarding it - -// Deprecated: we never really use this type of column to denote actual public -// inputs. The plan is to resort to using [query.LocalOpeningParams] instead. -func (c *CompiledIOP) InsertPublicInput(round int, name ifaces.ColID, size int) (msg ifaces.Column) { - c.assertConsistentRound(round) - - // Common : No zero length - if size == 0 { - utils.Panic("when registering %v, VecType with length zero", name) - } - - return c.Columns.AddToRound(round, name, size, column.PublicInput) -} - // InsertVerifier registers a verifier steps into the current CompiledIOP; // meaning a "native" Go function that performs one or more checks involving // wizard items that are accessible to the verifier of the specified protocol. @@ -479,7 +450,6 @@ func (c *CompiledIOP) InsertPublicInput(round int, name ifaces.ColID, size int) // not intend to run the verifier of the Wizard protocol in a gnark circuit, // passing `nil` is fine. func (c *CompiledIOP) InsertVerifier(round int, ver VerifierStep, gnarkVer GnarkVerifierStep) { - c.assertConsistentRound(round) c.SubVerifiers.AppendToInner(round, &genVerifierAction{ run: ver, runGnark: gnarkVer, @@ -510,7 +480,6 @@ func (c *CompiledIOP) InsertRange(round int, name ifaces.QueryID, h ifaces.Colum panic("max is zero : perhaps an overflow") } - c.assertConsistentRound(round) /* In case the range is applied over a composite handle. We apply the range over each natural component of the handle. @@ -531,7 +500,6 @@ func (c *CompiledIOP) InsertRange(round int, name ifaces.QueryID, h ifaces.Colum // - a query with the same name has already been registered in the Wizard // - the provided columns `a` and `bs` do not all have the same size func (c *CompiledIOP) InsertInnerProduct(round int, name ifaces.QueryID, a ifaces.Column, bs []ifaces.Column) query.InnerProduct { - c.assertConsistentRound(round) // Also ensures that the query round does not predates the columns rounds maxComRound := a.Round() @@ -565,7 +533,6 @@ func (run *CompiledIOP) GetInnerProduct(name ifaces.QueryID) query.InnerProduct // - the name is the empty string // - a query with the same name has already been registered in the Wizard func (c *CompiledIOP) InsertUnivariate(round int, name ifaces.QueryID, pols []ifaces.Column) query.UnivariateEval { - c.assertConsistentRound(round) q := query.NewUnivariateEval(name, pols...) // Finally registers the query c.QueriesParams.AddToRound(round, name, q) @@ -576,7 +543,6 @@ func (c *CompiledIOP) InsertUnivariate(round int, name ifaces.QueryID, pols []if // in the current CompiledIOP. A local opening query requires the prover of the // protocol to "open" the first position of the vector. func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifaces.Column) query.LocalOpening { - c.assertConsistentRound(round) q := query.NewLocalOpening(name, pol) // Finally registers the query c.QueriesParams.AddToRound(round, name, q) @@ -587,21 +553,12 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa // It generates a single global summation for many Sigma Columns from Lookup compilation. // The sigma columns are categorized by [round,size]. func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[int]*query.LogDerivativeSumInput) query.LogDerivativeSum { - c.assertConsistentRound(lastRound) q := query.NewLogDerivativeSum(lastRound, in, id) // Finally registers the query c.QueriesParams.AddToRound(lastRound, id, q) return q } -// assertConsistentRound compares the round passed as an argument and panic if it greater than -// coin.Round. This helps ensuring that we do not have "useless" rounds. -func (c *CompiledIOP) assertConsistentRound(round int) { - if round > c.Coins.NumRounds() { - utils.Panic("Inserted at round %v, but the max should be %v", round, c.Coins.NumRounds()) - } -} - // InsertMiMC declares a MiMC constraints query; a constraint that all the // entries of new are obtained by running the compression function of MiMC over // the entries of block and old, row-by-row. @@ -611,7 +568,6 @@ func (c *CompiledIOP) assertConsistentRound(round int) { // - the declaration round is anterior to the declaration round of the // provided input columns. func (c *CompiledIOP) InsertMiMC(round int, id ifaces.QueryID, block, old, new ifaces.Column) query.MiMC { - c.assertConsistentRound(round) q := query.NewMiMC(id, block, old, new) c.QueriesNoParams.AddToRound(round, id, q) return q @@ -647,7 +603,6 @@ func (c *CompiledIOP) RegisterVerifierAction(round int, action VerifierAction) { // Register a GrandProduct query func (c *CompiledIOP) InsertGrandProduct(round int, id ifaces.QueryID, in map[int]*query.GrandProductInput) query.GrandProduct { - c.assertConsistentRound(round) q := query.NewGrandProduct(round, in, id) // Finally registers the query c.QueriesParams.AddToRound(round, q.Name(), q) @@ -685,9 +640,33 @@ func (c *CompiledIOP) InsertProjection(id ifaces.QueryID, in query.ProjectionInp in.FilterA.Round(), in.FilterB.Round()) ) - c.assertConsistentRound(round) q := query.NewProjection(round, id, in) // Finally registers the query c.QueriesNoParams.AddToRound(round, q.Name(), q) return q } + +// AddPublicInput inserts a public-input in the compiled-IOP +func (c *CompiledIOP) InsertPublicInput(name string, acc ifaces.Accessor) PublicInput { + + res := PublicInput{ + Name: name, + Acc: acc, + } + + c.PublicInputs = append(c.PublicInputs, res) + return res +} + +// GetPublicInputAccessor attempts to find a public input with the provided name +// and panic if it fails to do so. The method returns the accessor in case of +// success. +func (c *CompiledIOP) GetPublicInputAccessor(name string) ifaces.Accessor { + for _, pi := range c.PublicInputs { + if pi.Name == name { + return pi.Acc + } + } + utils.Panic("could not find public input %v", name) + return nil // unreachable +} diff --git a/prover/protocol/wizard/fiatshamir.go b/prover/protocol/wizard/fiatshamir.go index 59ecbaa7e72..109d92438f5 100644 --- a/prover/protocol/wizard/fiatshamir.go +++ b/prover/protocol/wizard/fiatshamir.go @@ -34,7 +34,7 @@ func (comp *CompiledIOP) BootstrapFiatShamir(vm VersionMetadata, ser CompiledIOP // hasher.Write(compBlob) digest := hasher.Sum(nil) digest[0] = 0 // This is to prevent potential errors due to overflowing the field - comp.fiatShamirSetup.SetBytes(digest) + comp.FiatShamirSetup.SetBytes(digest) return comp } diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index a07ce74cb58..fc2c1292f4e 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -16,9 +16,30 @@ import ( "github.com/sirupsen/logrus" ) +// GnarkRuntime is the interface implemented by the struct [WizardVerifierCircuit] +// and is used to interact with the GnarkVerifierStep. +type GnarkRuntime interface { + ifaces.GnarkRuntime + GetSpec() *CompiledIOP + GetPublicInput(api frontend.API, name string) frontend.Variable + GetGrandProductParams(name ifaces.QueryID) query.GnarkGrandProductParams + GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams + GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams + GetInnerProductParams(name ifaces.QueryID) query.GnarkInnerProductParams + GetUnivariateEval(name ifaces.QueryID) query.UnivariateEval + GetUnivariateParams(name ifaces.QueryID) query.GnarkUnivariateEvalParams + Fs() *fiatshamir.GnarkFiatShamir + FsHistory() [][2][]frontend.Variable + GetHasherFactory() *gkrmimc.HasherFactory + InsertCoin(name coin.Name, value interface{}) + GetState(name string) (any, bool) + SetState(name string, value any) + GetQuery(name ifaces.QueryID) ifaces.Query +} + // GnarkVerifierStep functions that can be registered in the CompiledIOP by the successive // compilation steps. They correspond to "precompiled" verification steps. -type GnarkVerifierStep func(frontend.API, *WizardVerifierCircuit) +type GnarkVerifierStep func(frontend.API, GnarkRuntime) // WizardVerifierCircuit the [VerifierRuntime] in a gnark circuit. The complete // implementation follows this mirror logic. @@ -96,6 +117,10 @@ type WizardVerifierCircuit struct { // round. The first entry is the initial state, the final entry is the final // state. FiatShamirHistory [][2][]frontend.Variable `gnark:"-"` + + // State is a generic-purpose data store that the verifier steps can use to + // communicate with each other across rounds. + State map[string]interface{} `gnark:"-"` } // AllocateWizardCircuit allocates the inner-slices of the verifier struct from a precompiled IOP. It @@ -161,7 +186,7 @@ func AllocateWizardCircuit(comp *CompiledIOP) (*WizardVerifierCircuit, error) { func (c *WizardVerifierCircuit) Verify(api frontend.API) { c.HasherFactory = gkrmimc.NewHasherFactory(api) c.FS = fiatshamir.NewGnarkFiatShamir(api, c.HasherFactory) - c.FS.Update(c.Spec.fiatShamirSetup) + c.FS.Update(c.Spec.FiatShamirSetup) c.FiatShamirHistory = make([][2][]frontend.Variable, c.Spec.NumRounds()) c.generateAllRandomCoins(api) @@ -190,12 +215,11 @@ func (c *WizardVerifierCircuit) generateAllRandomCoins(api frontend.API) { // the last one to "talk" in the protocol. toUpdateFS := c.Spec.Columns.AllKeysProofAt(currRound - 1) for _, msg := range toUpdateFS { - msgContent := c.GetColumn(msg) - c.FS.UpdateVec(msgContent) - } - toUpdateFS = c.Spec.Columns.AllKeysPublicInputAt(currRound - 1) - for _, msg := range toUpdateFS { + if c.Spec.Columns.IsExplicitlyExcludedFromProverFS(msg) { + continue + } + msgContent := c.GetColumn(msg) c.FS.UpdateVec(msgContent) } @@ -554,3 +578,63 @@ func (c *WizardVerifierCircuit) GetPublicInput(api frontend.API, name string) fr utils.Panic("could not find public input nb %v", name) return field.Element{} } + +// Fs returns the Fiat-Shamir state of the verifier circuit +func (c *WizardVerifierCircuit) Fs() *fiatshamir.GnarkFiatShamir { + return c.FS +} + +// FsHistory returns the Fiat-Shamir state history of the verifier circuit +func (c *WizardVerifierCircuit) FsHistory() [][2][]frontend.Variable { + return c.FiatShamirHistory +} + +// SetFs sets the Fiat-Shamir state of the verifier circuit +func (c *WizardVerifierCircuit) SetFs(fs *fiatshamir.GnarkFiatShamir) { + c.FS = fs +} + +// GetHasherFactory returns the hasher factory of the verifier circuit; nil +// if none is set. +func (c *WizardVerifierCircuit) GetHasherFactory() *gkrmimc.HasherFactory { + return c.HasherFactory +} + +// SetHasherFactory sets the hasher factory of the verifier circuit +func (c *WizardVerifierCircuit) SetHasherFactory(hf *gkrmimc.HasherFactory) { + c.HasherFactory = hf +} + +// GetSpec returns the compiled IOP of the verifier circuit +func (c *WizardVerifierCircuit) GetSpec() *CompiledIOP { + return c.Spec +} + +// InsertCoin inserts a coin in the verifier circuit. This has +// a use for implementing recursive application. +func (c *WizardVerifierCircuit) InsertCoin(name coin.Name, value interface{}) { + c.Coins.InsertNew(name, value) +} + +// GetState returns the value of a state variable in the verifier circuit +func (c *WizardVerifierCircuit) GetState(name string) (any, bool) { + res, ok := c.State[name] + return res, ok +} + +// SetState sets the value of a state variable in the verifier circuit +func (c *WizardVerifierCircuit) SetState(name string, value any) { + c.State[name] = value +} + +// GetQuery returns a query from its name +func (c *WizardVerifierCircuit) GetQuery(name ifaces.QueryID) ifaces.Query { + if c.Spec.QueriesParams.Exists(name) { + return c.Spec.QueriesParams.Data(name) + } + if c.Spec.QueriesNoParams.Exists(name) { + return c.Spec.QueriesNoParams.Data(name) + } + utils.Panic("could not find query nb %v", name) + return nil +} diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index 7a682c05d7d..5b1aa151e50 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -158,29 +158,7 @@ type ProverRuntime struct { // when the specified protocol is complicated and involves multiple multi-rounds // sub-protocols that runs independently. func Prove(c *CompiledIOP, highLevelprover ProverStep) Proof { - - runtime := RunProver(c, highLevelprover) - - /* - Pass all the prover message columns as part of the proof - */ - messages := collection.NewMapping[ifaces.ColID, ifaces.ColAssignment]() - - for _, name := range runtime.Spec.Columns.AllKeysProof() { - messageValue := runtime.Columns.MustGet(name) - messages.InsertNew(name, messageValue) - } - - // And also the public inputs - for _, name := range runtime.Spec.Columns.AllKeysPublicInput() { - messageValue := runtime.Columns.MustGet(name) - messages.InsertNew(name, messageValue) - } - - return Proof{ - Messages: messages, - QueriesParams: runtime.QueriesParams, - } + return RunProver(c, highLevelprover).ExtractProof() } // RunProver initializes a [ProverRuntime], runs the prover and returns the final @@ -207,6 +185,22 @@ func RunProver(c *CompiledIOP, highLevelprover ProverStep) *ProverRuntime { return &runtime } +// RunProverUntilRound runs the prover until the specified round +func RunProverUntilRound(c *CompiledIOP, highLevelprover ProverStep, round int) *ProverRuntime { + + runtime := c.createProver() + + highLevelprover(&runtime) + runtime.runProverSteps() + + for runtime.currRound+1 < round { + runtime.goNextRound() + runtime.runProverSteps() + } + + return &runtime +} + // ProveOnlyFirstRound computes the first round of the prover and returns the // resulting ProverRuntime containing all the generated assignments. func ProverOnlyFirstRound(c *CompiledIOP, highLevelprover ProverStep) *ProverRuntime { @@ -226,6 +220,37 @@ func ProverOnlyFirstRound(c *CompiledIOP, highLevelprover ProverStep) *ProverRun return &runtime } +// ExtractProof extracts the proof from a [ProverRuntime]. If the runtime has +// been obtained via a [RunProverUntilRound], then it may be the case that +// some columns have not been assigned at all. Those won't be included in the +// returned proof. +func (run *ProverRuntime) ExtractProof() Proof { + messages := collection.NewMapping[ifaces.ColID, ifaces.ColAssignment]() + + for _, name := range run.Spec.Columns.AllKeysProof() { + + cols := run.Spec.Columns.GetHandle(name) + if run.currRound < cols.Round() { + continue + } + + messageValue := run.Columns.MustGet(name) + messages.InsertNew(name, messageValue) + } + + queriesParams := collection.NewMapping[ifaces.QueryID, ifaces.QueryParams]() + for round := 0; round <= run.currRound; round++ { + for _, name := range run.Spec.QueriesParams.AllKeysAt(round) { + queriesParams.InsertNew(name, run.QueriesParams.MustGet(name)) + } + } + + return Proof{ + Messages: messages, + QueriesParams: queriesParams, + } +} + // NumRounds returns the total number of rounds in the corresponding WizardIOP. // // Deprecated: this method does not bring anything useful as its already easy @@ -244,7 +269,7 @@ func (c *CompiledIOP) createProver() ProverRuntime { // Create a new fresh FS state and bootstrap it fs := fiatshamir.NewMiMCFiatShamir() - fs.Update(c.fiatShamirSetup) + fs.Update(c.FiatShamirSetup) // Instantiates an empty Assignment (but link it to the CompiledIOP) runtime := ProverRuntime{ @@ -516,27 +541,6 @@ func (run *ProverRuntime) goNextRound() { initialState := run.FS.State() - /* - Make sure all issued random coin have been "consumed" by all the prover - steps, in the round we are closing. An error occuring here is more likely - an error in the compiler than an error from the user because it is not - responsible for setting the coin. Thus, this is more a sanity check. - */ - toBeConsumed := run.Spec.Coins.AllKeysAt(run.currRound) - run.Coins.MustExists(toBeConsumed...) - - /* - We do not make this check for the columns, the reason is that we delete - the columns that we do not use anymore. - */ - - /* - Then, make sure all the query parameters have been set - during the rounds we are closing - */ - toBeParametrized := run.Spec.QueriesParams.AllKeysAt(run.currRound) - run.QueriesParams.MustExists(toBeParametrized...) - if !run.Spec.DummyCompiled { /* @@ -545,20 +549,17 @@ func (run *ProverRuntime) goNextRound() { FS using the last round of the prover because he is always the last one to "talk" in the protocol. */ - msgsToFS := run.Spec.Columns.AllKeysProofsOrIgnoredButKeptInProverTranscript(run.currRound) + msgsToFS := run.Spec.Columns.AllKeysInProverTranscript(run.currRound) for _, msgName := range msgsToFS { - instance := run.GetMessage(msgName) - run.FS.UpdateSV(instance) - } - /* - Make sure that all messages have been written and use them - to update the FS state. Note that we do not need to update - FS using the last round of the prover because he is always - the last one to "talk" in the protocol. - */ - msgsToFS = run.Spec.Columns.AllKeysPublicInputAt(run.currRound) - for _, msgName := range msgsToFS { + if run.Spec.Columns.IsExplicitlyExcludedFromProverFS(msgName) { + continue + } + + if run.Spec.Precomputed.Exists(msgName) { + continue + } + instance := run.GetMessage(msgName) run.FS.UpdateSV(instance) } @@ -590,11 +591,16 @@ func (run *ProverRuntime) goNextRound() { toCompute := run.Spec.Coins.AllKeysAt(run.currRound) for _, myCoin := range toCompute { + var ( info = run.Spec.Coins.Data(myCoin) value interface{} ) + if run.Spec.Coins.IsSkippedFromProverTranscript(info.Name) { + continue + } + if info.Type == coin.FieldFromSeed { // if it is of type FromSeed, sample a coin based on the seed if seed, ok := run.ParentRuntime.Coins.MustGet("SEED").(field.Element); ok { diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index f53ac85a7fe..aaacfe165fb 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -33,11 +33,32 @@ type Proof struct { QueriesParams collection.Mapping[ifaces.QueryID, ifaces.QueryParams] } +// Runtime is a generic interface extending the [ifaces.Runtime] interface +// with all methods of [wizard.VerifierRuntime]. This is used to allow the +// writing of adapters for the verifier runtime. +type Runtime interface { + ifaces.Runtime + GetSpec() *CompiledIOP + GetPublicInput(name string) field.Element + GetGrandProductParams(name ifaces.QueryID) query.GrandProductParams + GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams + GetLocalPointEvalParams(name ifaces.QueryID) query.LocalOpeningParams + GetInnerProductParams(name ifaces.QueryID) query.InnerProductParams + GetUnivariateEval(name ifaces.QueryID) query.UnivariateEval + GetUnivariateParams(name ifaces.QueryID) query.UnivariateEvalParams + GetQuery(name ifaces.QueryID) ifaces.Query + Fs() *fiatshamir.State + FsHistory() [][2][]field.Element + InsertCoin(name coin.Name, value any) + GetState(name string) (any, bool) + SetState(name string, value any) +} + // VerifierStep specifies a single step of verifier for a single subprotocol. // This can be used to specify verifier checks involving user-provided // columns for relations that cannot be automatically enforced via a // [ifaces.Query] -type VerifierStep func(a *VerifierRuntime) error +type VerifierStep func(a Runtime) error // VerifierRuntime runtime collects all data that visible or computed by the // verifier of the wizard protocol. This includes the prover's messages, the @@ -73,6 +94,11 @@ type VerifierRuntime struct { // round. The first entry is the initial state, the final entry is the final // state. FiatShamirHistory [][2][]field.Element + + // State stores arbitrary data that can be used by the verifier. This + // can be used to communicate values between verifier states. + State map[string]interface{} + // the parent run time; used in the distributed setting, // to provide access to the seed generated by the parent. ParentRuntime *VerifierRuntime @@ -170,9 +196,10 @@ func (c *CompiledIOP) createVerifier(proof Proof) VerifierRuntime { QueriesParams: proof.QueriesParams, FS: fiatshamir.NewMiMCFiatShamir(), FiatShamirHistory: make([][2][]field.Element, c.NumRounds()), + State: make(map[string]interface{}), } - runtime.FS.Update(c.fiatShamirSetup) + runtime.FS.Update(c.FiatShamirSetup) /* Insert the verifying key into the messages @@ -209,12 +236,11 @@ func (run *VerifierRuntime) generateAllRandomCoins() { */ msgsToFS := run.Spec.Columns.AllKeysProofAt(currRound - 1) for _, msgName := range msgsToFS { - instance := run.GetColumn(msgName) - run.FS.UpdateSV(instance) - } - msgsToFS = run.Spec.Columns.AllKeysPublicInputAt(currRound - 1) - for _, msgName := range msgsToFS { + if run.Spec.Columns.IsExplicitlyExcludedFromProverFS(msgName) { + continue + } + instance := run.GetColumn(msgName) run.FS.UpdateSV(instance) } @@ -240,6 +266,7 @@ func (run *VerifierRuntime) generateAllRandomCoins() { */ toCompute := run.Spec.Coins.AllKeysAt(currRound) for _, myCoin := range toCompute { + if run.Spec.Coins.IsSkippedFromVerifierTranscript(myCoin) { continue } @@ -468,3 +495,51 @@ func (run *VerifierRuntime) GetPublicInput(name string) field.Element { utils.Panic("could not find public input nb %v", name) return field.Element{} } + +// Fs returns the Fiat-Shamir state +func (run *VerifierRuntime) Fs() *fiatshamir.State { + return run.FS +} + +// FsHistory returns the Fiat-Shamir state history +func (run *VerifierRuntime) FsHistory() [][2][]field.Element { + return run.FiatShamirHistory +} + +// GetSpec returns the compiled IOP +func (run *VerifierRuntime) GetSpec() *CompiledIOP { + return run.Spec +} + +// InsertCoin inserts a coin into the runtime. It should not be +// used by usual verifier action but is useful when implementing +// recursion utilities. +func (run *VerifierRuntime) InsertCoin(name coin.Name, value any) { + run.Coins.InsertNew(name, value) +} + +// GetState returns an arbitrary value stored in the runtime +func (run *VerifierRuntime) GetState(name string) (any, bool) { + res, ok := run.State[name] + return res, ok +} + +// SetState sets an arbitrary value in the runtime +func (run *VerifierRuntime) SetState(name string, value any) { + run.State[name] = value +} + +// GetQuery returns a query from its name +func (run *VerifierRuntime) GetQuery(name ifaces.QueryID) ifaces.Query { + + if run.Spec.QueriesParams.Exists(name) { + return run.Spec.QueriesParams.Data(name) + } + + if run.Spec.QueriesNoParams.Exists(name) { + return run.Spec.QueriesNoParams.Data(name) + } + + utils.Panic("could not find query nb %v", name) + return nil +}